diff --git a/adapter/adapter.go b/adapter/adapter.go
index ef8d4ee3..591f0517 100644
--- a/adapter/adapter.go
+++ b/adapter/adapter.go
@@ -177,14 +177,12 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
 		p.history.Pop()
 	}

-	state, ok := p.extra.Load(url)
-	if !ok {
-		state = &internalProxyState{
-			history: queue.New[C.DelayHistory](defaultHistoriesNum),
-			alive:   atomic.NewBool(true),
-		}
-		p.extra.Store(url, state)
+	// Use LoadOrStore to avoid race condition
+	newState := &internalProxyState{
+		history: queue.New[C.DelayHistory](defaultHistoriesNum),
+		alive:   atomic.NewBool(true),
 	}
+	state, _ := p.extra.LoadOrStore(url, newState)

 	if !satisfied {
 		record.Delay = 0
diff --git a/adapter/outboundgroup/groupbase.go b/adapter/outboundgroup/groupbase.go
index 9e705677..dcf1311c 100644
--- a/adapter/outboundgroup/groupbase.go
+++ b/adapter/outboundgroup/groupbase.go
@@ -117,28 +117,43 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
 		return gb.providerProxies
 	}

+	// Pre-calculate total capacity for better allocation
+	var totalCapacity int
+	for _, pd := range gb.providers {
+		totalCapacity += len(pd.Proxies())
+	}
+
 	var proxies []C.Proxy
 	if len(gb.filterRegs) == 0 {
+		proxies = make([]C.Proxy, 0, totalCapacity)
 		for _, pd := range gb.providers {
 			proxies = append(proxies, pd.Proxies()...)
 		}
 	} else {
+		proxies = make([]C.Proxy, 0, totalCapacity)
 		for _, pd := range gb.providers {
 			if pd.VehicleType() == P.Compatible { // compatible provider unneeded filter
 				proxies = append(proxies, pd.Proxies()...)
 				continue
 			}

-			var newProxies []C.Proxy
-			proxiesSet := map[string]struct{}{}
-			for _, filterReg := range gb.filterRegs {
-				for _, p := range pd.Proxies() {
-					name := p.Name()
+			pdProxies := pd.Proxies()
+			newProxies := make([]C.Proxy, 0, len(pdProxies))
+			proxiesSet := make(map[string]struct{}, len(pdProxies))
+
+			// Optimize: loop through proxies first, then check all filters
+			for _, p := range pdProxies {
+				name := p.Name()
+				if _, exists := proxiesSet[name]; exists {
+					continue
+				}
+
+				// Check if proxy matches any filter
+				for _, filterReg := range gb.filterRegs {
 					if mat, _ := filterReg.MatchString(name); mat {
-						if _, ok := proxiesSet[name]; !ok {
-							proxiesSet[name] = struct{}{}
-							newProxies = append(newProxies, p)
-						}
+						proxiesSet[name] = struct{}{}
+						newProxies = append(newProxies, p)
+						break // Found a match, no need to check other filters
 					}
 				}
 			}
@@ -150,22 +165,23 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
 	// Although the filter has been performed once in the previous process,
 	// when there are multiple providers, the array needs to be reordered as a whole.
 	if len(gb.providers) > 1 && len(gb.filterRegs) > 1 {
-		var newProxies []C.Proxy
-		proxiesSet := map[string]struct{}{}
+		newProxies := make([]C.Proxy, 0, len(proxies))
+		proxiesSet := make(map[string]struct{}, len(proxies))
 		for _, filterReg := range gb.filterRegs {
 			for _, p := range proxies {
 				name := p.Name()
+				if _, exists := proxiesSet[name]; exists {
+					continue
+				}
 				if mat, _ := filterReg.MatchString(name); mat {
-					if _, ok := proxiesSet[name]; !ok {
-						proxiesSet[name] = struct{}{}
-						newProxies = append(newProxies, p)
-					}
+					proxiesSet[name] = struct{}{}
+					newProxies = append(newProxies, p)
 				}
 			}
 		}
 		for _, p := range proxies { // add not matched proxies at the end
 			name := p.Name()
-			if _, ok := proxiesSet[name]; !ok {
+			if _, exists := proxiesSet[name]; !exists {
 				proxiesSet[name] = struct{}{}
 				newProxies = append(newProxies, p)
 			}
@@ -174,7 +190,7 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
 	}

 	if len(gb.excludeFilterRegs) > 0 {
-		var newProxies []C.Proxy
+		newProxies := make([]C.Proxy, 0, len(proxies))
 	LOOP1:
 		for _, p := range proxies {
 			name := p.Name()
diff --git a/dns/doq.go b/dns/doq.go
index 6861bd4d..75a9917a 100644
--- a/dns/doq.go
+++ b/dns/doq.go
@@ -191,7 +191,8 @@ func (doq *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.M
 }

 // AddPrefix adds a 2-byte prefix with the DNS message length.
 func AddPrefix(b []byte) (m []byte) {
-	m = make([]byte, 2+len(b))
+	msgLen := 2 + len(b)
+	m = make([]byte, msgLen)
 	binary.BigEndian.PutUint16(m, uint16(len(b)))
 	copy(m[2:], b)
diff --git a/main.go b/main.go
index 6f21e811..7f40333b 100644
--- a/main.go
+++ b/main.go
@@ -63,6 +63,7 @@ func main() {
 	net.DefaultResolver.PreferGo = true
 	net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
 		//panic("should never be called")
+		const maxStackSize = 65536 // 64KB max buffer size
 		buf := make([]byte, 1024)
 		for {
 			n := runtime.Stack(buf, true)
@@ -70,6 +71,12 @@ func main() {
 				buf = buf[:n]
 				break
 			}
+			// Prevent unbounded growth
+			if len(buf) >= maxStackSize {
+				fmt.Fprintf(os.Stderr, "panic: should never be called (stack trace truncated at %d bytes)\n\n%s", maxStackSize, buf)
+				os.Exit(2)
+				return nil, nil
+			}
 			buf = make([]byte, 2*len(buf))
 		}
 		fmt.Fprintf(os.Stderr, "panic: should never be called\n\n%s", buf) // always print all goroutine stack
diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go
index 1b2f4136..156a36c2 100644
--- a/tunnel/tunnel.go
+++ b/tunnel/tunnel.go
@@ -64,6 +64,10 @@ var (
 	snifferDispatcher *sniffer.Dispatcher
 	sniffingEnable    = false

+	// Cache for ProxiesWithProviders
+	cachedProxiesWithProviders map[string]C.Proxy
+	proxiesWithProvidersValid  bool
+
 	ruleUpdateCallback = utils.NewCallback[P.RuleProvider]()
 )

@@ -141,10 +145,10 @@ func Status() TunnelStatus {
 }

 func SetSniffing(b bool) {
+	configMux.Lock()
+	defer configMux.Unlock()
 	if snifferDispatcher.Enable() {
-		configMux.Lock()
 		sniffingEnable = b
-		configMux.Unlock()
 	}
 }

@@ -209,6 +213,24 @@ func Proxies() map[string]C.Proxy {
 }

 func ProxiesWithProviders() map[string]C.Proxy {
+	configMux.RLock()
+	// Return cached result if valid
+	if proxiesWithProvidersValid && cachedProxiesWithProviders != nil {
+		result := cachedProxiesWithProviders
+		configMux.RUnlock()
+		return result
+	}
+	configMux.RUnlock()
+
+	// Build new cache
+	configMux.Lock()
+	defer configMux.Unlock()
+
+	// Double-check after acquiring write lock
+	if proxiesWithProvidersValid && cachedProxiesWithProviders != nil {
+		return cachedProxiesWithProviders
+	}
+
 	allProxies := make(map[string]C.Proxy)
 	for name, proxy := range proxies {
 		allProxies[name] = proxy
@@ -219,6 +241,11 @@ func ProxiesWithProviders() map[string]C.Proxy {
 			allProxies[name] = proxy
 		}
 	}
+
+	// Cache the result
+	cachedProxiesWithProviders = allProxies
+	proxiesWithProvidersValid = true
+
 	return allProxies
 }

@@ -237,6 +264,9 @@ func UpdateProxies(newProxies map[string]C.Proxy, newProviders map[string]P.Prox
 	configMux.Lock()
 	proxies = newProxies
 	providers = newProviders
+	// Invalidate cache when proxies/providers are updated
+	proxiesWithProvidersValid = false
+	cachedProxiesWithProviders = nil
 	configMux.Unlock()
 }

@@ -492,20 +522,11 @@ func handleUDPConn(packet C.PacketAdapter) {
 	sender.Send(packet) // nonblocking
 }

-func handleTCPConn(connCtx C.ConnContext) {
-	if !isHandle(connCtx.Metadata().Type) {
-		_ = connCtx.Conn().Close()
-		return
-	}
-
-	defer func(conn net.Conn) {
-		_ = conn.Close()
-	}(connCtx.Conn())
-
-	metadata := connCtx.Metadata()
+// prepareMetadata validates and prepares metadata for connection handling
+func prepareMetadata(metadata *C.Metadata, conn C.Conn) bool {
 	if !metadata.Valid() {
 		log.Warnln("[Metadata] not valid: %#v", metadata)
-		return
+		return false
 	}

 	fixMetadata(metadata) // fix some metadata not set via metadata.SetRemoteAddr or metadata.SetRemoteAddress
@@ -515,7 +536,6 @@ func handleTCPConn(connCtx C.ConnContext) {
 		preHandleFailed = true
 	}

-	conn := connCtx.Conn()
 	conn.ResetPeeked() // reset before sniffer
 	if sniffingEnable && snifferDispatcher.Enable() {
 		// Try to sniff a domain when `preHandleMetadata` failed, this is usually
@@ -530,20 +550,40 @@ func handleTCPConn(connCtx C.ConnContext) {
 	if preHandleFailed {
 		log.Debugln("[Metadata PreHandle] failed to sniff a domain for connection %s --> %s, give up",
 			metadata.SourceDetail(), metadata.RemoteAddress())
+		return false
+	}
+
+	return true
+}
+
+// peekConnection triggers TLS handshake if needed
+func peekConnection(conn C.Conn) {
+	if !conn.Peeked() {
+		_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
+		_, _ = conn.Peek(1)
+		_ = conn.SetReadDeadline(time.Time{})
+	}
+}
+
+func handleTCPConn(connCtx C.ConnContext) {
+	if !isHandle(connCtx.Metadata().Type) {
+		_ = connCtx.Conn().Close()
 		return
 	}

-	peekMutex := sync.Mutex{}
-	if !conn.Peeked() {
-		peekMutex.Lock()
-		go func() {
-			defer peekMutex.Unlock()
-			_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
-			_, _ = conn.Peek(1)
-			_ = conn.SetReadDeadline(time.Time{})
-		}()
+	defer func(conn net.Conn) {
+		_ = conn.Close()
+	}(connCtx.Conn())
+
+	metadata := connCtx.Metadata()
+	conn := connCtx.Conn()
+
+	if !prepareMetadata(metadata, conn) {
+		return
 	}

+	peekConnection(conn)
+
 	proxy, rule, err := resolveMetadata(metadata)
 	if err != nil {
 		log.Warnln("[Metadata] parse failed: %s", err.Error())
@@ -627,7 +667,7 @@ func logMetadataErr(metadata *C.Metadata, rule C.Rule, proxy C.ProxyAdapter, err
 func logMetadata(metadata *C.Metadata, rule C.Rule, remoteConn C.Connection) {
 	switch {
 	case metadata.SpecialProxy != "":
-		log.Infoln("[%s] %s --> %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), remoteConn.Chains().String())
+		log.Infoln("[%s] %s --> %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), metadata.SpecialProxy)
 	case rule != nil:
 		if rule.Payload() != "" {
 			log.Infoln("[%s] %s --> %s match %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), fmt.Sprintf("%s(%s)", rule.RuleType().String(), rule.Payload()), remoteConn.Chains().String())
@@ -639,7 +679,7 @@ func logMetadata(metadata *C.Metadata, rule C.Rule, remoteConn C.Connection) {
 	case mode == Direct:
 		log.Infoln("[%s] %s --> %s using DIRECT", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress())
 	default:
-		log.Infoln("[%s] %s --> %s doesn't match any rule using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), remoteConn.Chains().String())
+		log.Infoln("[%s] %s --> %s doesn't match any rule using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), remoteConn.Chains().Last())
 	}
 }
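
Note on the adapter/adapter.go change: replacing the Load-then-Store sequence with a single LoadOrStore call closes the window in which two concurrent URLTest calls could each create and store their own state for the same URL. A minimal standalone sketch of the pattern using the standard library's sync.Map (the real p.extra field may be a different concurrent map type; urlState and getState are hypothetical names for this example only):

package main

import (
	"fmt"
	"sync"
)

// urlState is a hypothetical stand-in for internalProxyState.
type urlState struct {
	alive bool
}

var extra sync.Map // map[string]*urlState

// getState returns the state for url, creating it atomically if absent.
// LoadOrStore either returns the value already stored for the key or stores
// the supplied one, so concurrent callers always converge on a single state.
func getState(url string) *urlState {
	newState := &urlState{alive: true}
	state, _ := extra.LoadOrStore(url, newState)
	return state.(*urlState)
}

func main() {
	a := getState("https://example.com/test")
	b := getState("https://example.com/test")
	fmt.Println(a == b) // true: both callers observe the same stored state
}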
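
Note on the tunnel/tunnel.go caching: ProxiesWithProviders now uses a double-checked pattern under configMux (RLock fast path, write lock to rebuild, re-check after upgrading), and UpdateProxies invalidates the cache. A condensed sketch of the same pattern with simplified types (mu, data, snapshot, and update are illustrative names, not the real identifiers):

package main

import (
	"fmt"
	"sync"
)

var (
	mu    sync.RWMutex
	data  = map[string]int{"a": 1}
	cache map[string]int
	valid bool
)

// snapshot returns a cached copy of data, rebuilding it only when invalidated.
func snapshot() map[string]int {
	mu.RLock()
	if valid && cache != nil {
		result := cache
		mu.RUnlock()
		return result
	}
	mu.RUnlock()

	mu.Lock()
	defer mu.Unlock()
	// Double-check: another goroutine may have rebuilt the cache while we
	// were waiting for the write lock.
	if valid && cache != nil {
		return cache
	}
	rebuilt := make(map[string]int, len(data))
	for k, v := range data {
		rebuilt[k] = v
	}
	cache = rebuilt
	valid = true
	return rebuilt
}

// update replaces data and invalidates the cache, mirroring UpdateProxies.
func update(newData map[string]int) {
	mu.Lock()
	data = newData
	valid = false
	cache = nil
	mu.Unlock()
}

func main() {
	fmt.Println(snapshot()["a"]) // 1
	update(map[string]int{"a": 2})
	fmt.Println(snapshot()["a"]) // 2
}

One consequence of the diff worth noting: callers of ProxiesWithProviders now share the cached map between calls (previously each call built a fresh map), so the returned map should be treated as read-only.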
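
Note on dns/doq.go: AddPrefix frames a DNS message for DoQ by prepending its length as a big-endian uint16. A self-contained sketch of the framing together with a hypothetical inverse (parsePrefix exists only for this example, it is not part of the codebase):

package main

import (
	"encoding/binary"
	"fmt"
)

// addPrefix prepends a 2-byte big-endian length prefix, mirroring AddPrefix.
func addPrefix(b []byte) []byte {
	m := make([]byte, 2+len(b))
	binary.BigEndian.PutUint16(m, uint16(len(b)))
	copy(m[2:], b)
	return m
}

// parsePrefix strips the length prefix and returns the original payload.
func parsePrefix(m []byte) ([]byte, error) {
	if len(m) < 2 {
		return nil, fmt.Errorf("message too short: %d bytes", len(m))
	}
	n := int(binary.BigEndian.Uint16(m))
	if len(m)-2 < n {
		return nil, fmt.Errorf("truncated message: want %d bytes, have %d", n, len(m)-2)
	}
	return m[2 : 2+n], nil
}

func main() {
	framed := addPrefix([]byte("dns query bytes"))
	payload, err := parsePrefix(framed)
	fmt.Println(string(payload), err) // "dns query bytes" <nil>
}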
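
Note on main.go: the debug resolver hook grows its buffer until runtime.Stack reports that the full goroutine dump fits, and the new maxStackSize check caps that growth. A small sketch of the capped-growth loop in isolation (allStacks is an illustrative helper, not part of the codebase):

package main

import (
	"fmt"
	"runtime"
)

// allStacks grows the buffer until the full goroutine dump fits,
// but never beyond maxSize, mirroring the capped loop added in main.go.
func allStacks(maxSize int) []byte {
	buf := make([]byte, 1024)
	for {
		n := runtime.Stack(buf, true)
		if n < len(buf) {
			return buf[:n] // the dump fit; trim to its actual length
		}
		if len(buf) >= maxSize {
			return buf // give up growing; return the truncated dump
		}
		buf = make([]byte, 2*len(buf))
	}
}

func main() {
	fmt.Printf("%d bytes of stack traces\n", len(allStacks(65536)))
}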