From 3b4f07cd4bb46943ed73f11679845073a88b9325 Mon Sep 17 00:00:00 2001 From: Xi Xu Date: Thu, 4 Dec 2025 15:39:48 +0800 Subject: [PATCH] feat: optimize proxy management and connection handling Improves thread safety in adapter URLTest by using LoadOrStore, optimizes proxy filtering and allocation in outbound group, adds caching for ProxiesWithProviders in tunnel, and refactors TCP connection handling for better metadata validation and TLS handshake triggering. Also adds stack trace size limit in main.go and minor improvements in DNS message prefixing. --- adapter/adapter.go | 12 ++-- adapter/outboundgroup/groupbase.go | 50 ++++++++++------ dns/doq.go | 3 +- main.go | 7 +++ tunnel/tunnel.go | 92 +++++++++++++++++++++--------- 5 files changed, 113 insertions(+), 51 deletions(-) diff --git a/adapter/adapter.go b/adapter/adapter.go index ef8d4ee3..591f0517 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -177,14 +177,12 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In p.history.Pop() } - state, ok := p.extra.Load(url) - if !ok { - state = &internalProxyState{ - history: queue.New[C.DelayHistory](defaultHistoriesNum), - alive: atomic.NewBool(true), - } - p.extra.Store(url, state) + // Use LoadOrStore to avoid race condition + newState := &internalProxyState{ + history: queue.New[C.DelayHistory](defaultHistoriesNum), + alive: atomic.NewBool(true), } + state, _ := p.extra.LoadOrStore(url, newState) if !satisfied { record.Delay = 0 diff --git a/adapter/outboundgroup/groupbase.go b/adapter/outboundgroup/groupbase.go index 9e705677..dcf1311c 100644 --- a/adapter/outboundgroup/groupbase.go +++ b/adapter/outboundgroup/groupbase.go @@ -117,28 +117,43 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy { return gb.providerProxies } + // Pre-calculate total capacity for better allocation + var totalCapacity int + for _, pd := range gb.providers { + totalCapacity += len(pd.Proxies()) + } + var proxies []C.Proxy if len(gb.filterRegs) == 0 { + proxies = make([]C.Proxy, 0, totalCapacity) for _, pd := range gb.providers { proxies = append(proxies, pd.Proxies()...) } } else { + proxies = make([]C.Proxy, 0, totalCapacity) for _, pd := range gb.providers { if pd.VehicleType() == P.Compatible { // compatible provider unneeded filter proxies = append(proxies, pd.Proxies()...) continue } - var newProxies []C.Proxy - proxiesSet := map[string]struct{}{} - for _, filterReg := range gb.filterRegs { - for _, p := range pd.Proxies() { - name := p.Name() + pdProxies := pd.Proxies() + newProxies := make([]C.Proxy, 0, len(pdProxies)) + proxiesSet := make(map[string]struct{}, len(pdProxies)) + + // Optimize: loop through proxies first, then check all filters + for _, p := range pdProxies { + name := p.Name() + if _, exists := proxiesSet[name]; exists { + continue + } + + // Check if proxy matches any filter + for _, filterReg := range gb.filterRegs { if mat, _ := filterReg.MatchString(name); mat { - if _, ok := proxiesSet[name]; !ok { - proxiesSet[name] = struct{}{} - newProxies = append(newProxies, p) - } + proxiesSet[name] = struct{}{} + newProxies = append(newProxies, p) + break // Found a match, no need to check other filters } } } @@ -150,22 +165,23 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy { // Although the filter has been performed once in the previous process, // when there are multiple providers, the array needs to be reordered as a whole. if len(gb.providers) > 1 && len(gb.filterRegs) > 1 { - var newProxies []C.Proxy - proxiesSet := map[string]struct{}{} + newProxies := make([]C.Proxy, 0, len(proxies)) + proxiesSet := make(map[string]struct{}, len(proxies)) for _, filterReg := range gb.filterRegs { for _, p := range proxies { name := p.Name() + if _, exists := proxiesSet[name]; exists { + continue + } if mat, _ := filterReg.MatchString(name); mat { - if _, ok := proxiesSet[name]; !ok { - proxiesSet[name] = struct{}{} - newProxies = append(newProxies, p) - } + proxiesSet[name] = struct{}{} + newProxies = append(newProxies, p) } } } for _, p := range proxies { // add not matched proxies at the end name := p.Name() - if _, ok := proxiesSet[name]; !ok { + if _, exists := proxiesSet[name]; !exists { proxiesSet[name] = struct{}{} newProxies = append(newProxies, p) } @@ -174,7 +190,7 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy { } if len(gb.excludeFilterRegs) > 0 { - var newProxies []C.Proxy + newProxies := make([]C.Proxy, 0, len(proxies)) LOOP1: for _, p := range proxies { name := p.Name() diff --git a/dns/doq.go b/dns/doq.go index 6861bd4d..75a9917a 100644 --- a/dns/doq.go +++ b/dns/doq.go @@ -191,7 +191,8 @@ func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.M // AddPrefix adds a 2-byte prefix with the DNS message length. func AddPrefix(b []byte) (m []byte) { - m = make([]byte, 2+len(b)) + msgLen := 2 + len(b) + m = make([]byte, msgLen) binary.BigEndian.PutUint16(m, uint16(len(b))) copy(m[2:], b) diff --git a/main.go b/main.go index 6f21e811..7f40333b 100644 --- a/main.go +++ b/main.go @@ -63,6 +63,7 @@ func main() { net.DefaultResolver.PreferGo = true net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { //panic("should never be called") + const maxStackSize = 65536 // 64KB max buffer size buf := make([]byte, 1024) for { n := runtime.Stack(buf, true) @@ -70,6 +71,12 @@ func main() { buf = buf[:n] break } + // Prevent unbounded growth + if len(buf) >= maxStackSize { + fmt.Fprintf(os.Stderr, "panic: should never be called (stack trace truncated at %d bytes)\n\n%s", maxStackSize, buf) + os.Exit(2) + return nil, nil + } buf = make([]byte, 2*len(buf)) } fmt.Fprintf(os.Stderr, "panic: should never be called\n\n%s", buf) // always print all goroutine stack diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 1b2f4136..156a36c2 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -64,6 +64,10 @@ var ( snifferDispatcher *sniffer.Dispatcher sniffingEnable = false + // Cache for ProxiesWithProviders + cachedProxiesWithProviders map[string]C.Proxy + proxiesWithProvidersValid bool + ruleUpdateCallback = utils.NewCallback[P.RuleProvider]() ) @@ -141,10 +145,10 @@ func Status() TunnelStatus { } func SetSniffing(b bool) { + configMux.Lock() + defer configMux.Unlock() if snifferDispatcher.Enable() { - configMux.Lock() sniffingEnable = b - configMux.Unlock() } } @@ -209,6 +213,24 @@ func Proxies() map[string]C.Proxy { } func ProxiesWithProviders() map[string]C.Proxy { + configMux.RLock() + // Return cached result if valid + if proxiesWithProvidersValid && cachedProxiesWithProviders != nil { + result := cachedProxiesWithProviders + configMux.RUnlock() + return result + } + configMux.RUnlock() + + // Build new cache + configMux.Lock() + defer configMux.Unlock() + + // Double-check after acquiring write lock + if proxiesWithProvidersValid && cachedProxiesWithProviders != nil { + return cachedProxiesWithProviders + } + allProxies := make(map[string]C.Proxy) for name, proxy := range proxies { allProxies[name] = proxy @@ -219,6 +241,11 @@ func ProxiesWithProviders() map[string]C.Proxy { allProxies[name] = proxy } } + + // Cache the result + cachedProxiesWithProviders = allProxies + proxiesWithProvidersValid = true + return allProxies } @@ -237,6 +264,9 @@ func UpdateProxies(newProxies map[string]C.Proxy, newProviders map[string]P.Prox configMux.Lock() proxies = newProxies providers = newProviders + // Invalidate cache when proxies/providers are updated + proxiesWithProvidersValid = false + cachedProxiesWithProviders = nil configMux.Unlock() } @@ -492,20 +522,11 @@ func handleUDPConn(packet C.PacketAdapter) { sender.Send(packet) // nonblocking } -func handleTCPConn(connCtx C.ConnContext) { - if !isHandle(connCtx.Metadata().Type) { - _ = connCtx.Conn().Close() - return - } - - defer func(conn net.Conn) { - _ = conn.Close() - }(connCtx.Conn()) - - metadata := connCtx.Metadata() +// prepareMetadata validates and prepares metadata for connection handling +func prepareMetadata(metadata *C.Metadata, conn C.Conn) bool { if !metadata.Valid() { log.Warnln("[Metadata] not valid: %#v", metadata) - return + return false } fixMetadata(metadata) // fix some metadata not set via metadata.SetRemoteAddr or metadata.SetRemoteAddress @@ -515,7 +536,6 @@ func handleTCPConn(connCtx C.ConnContext) { preHandleFailed = true } - conn := connCtx.Conn() conn.ResetPeeked() // reset before sniffer if sniffingEnable && snifferDispatcher.Enable() { // Try to sniff a domain when `preHandleMetadata` failed, this is usually @@ -530,20 +550,40 @@ func handleTCPConn(connCtx C.ConnContext) { if preHandleFailed { log.Debugln("[Metadata PreHandle] failed to sniff a domain for connection %s --> %s, give up", metadata.SourceDetail(), metadata.RemoteAddress()) + return false + } + + return true +} + +// peekConnection triggers TLS handshake if needed +func peekConnection(conn C.Conn) { + if !conn.Peeked() { + _ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + _, _ = conn.Peek(1) + _ = conn.SetReadDeadline(time.Time{}) + } +} + +func handleTCPConn(connCtx C.ConnContext) { + if !isHandle(connCtx.Metadata().Type) { + _ = connCtx.Conn().Close() return } - peekMutex := sync.Mutex{} - if !conn.Peeked() { - peekMutex.Lock() - go func() { - defer peekMutex.Unlock() - _ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) - _, _ = conn.Peek(1) - _ = conn.SetReadDeadline(time.Time{}) - }() + defer func(conn net.Conn) { + _ = conn.Close() + }(connCtx.Conn()) + + metadata := connCtx.Metadata() + conn := connCtx.Conn() + + if !prepareMetadata(metadata, conn) { + return } + peekConnection(conn) + proxy, rule, err := resolveMetadata(metadata) if err != nil { log.Warnln("[Metadata] parse failed: %s", err.Error()) @@ -627,7 +667,7 @@ func logMetadataErr(metadata *C.Metadata, rule C.Rule, proxy C.ProxyAdapter, err func logMetadata(metadata *C.Metadata, rule C.Rule, remoteConn C.Connection) { switch { case metadata.SpecialProxy != "": - log.Infoln("[%s] %s --> %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), remoteConn.Chains().String()) + log.Infoln("[%s] %s --> %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), metadata.SpecialProxy) case rule != nil: if rule.Payload() != "" { log.Infoln("[%s] %s --> %s match %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), fmt.Sprintf("%s(%s)", rule.RuleType().String(), rule.Payload()), remoteConn.Chains().String()) @@ -639,7 +679,7 @@ func logMetadata(metadata *C.Metadata, rule C.Rule, remoteConn C.Connection) { case mode == Direct: log.Infoln("[%s] %s --> %s using DIRECT", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress()) default: - log.Infoln("[%s] %s --> %s doesn't match any rule using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), remoteConn.Chains().String()) + log.Infoln("[%s] %s --> %s doesn't match any rule using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), remoteConn.Chains().Last()) } }