From 082bcec281ecdb2cfb27730067eaec1c23429f41 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Thu, 12 Jun 2025 00:27:51 +0800 Subject: [PATCH] chore: apply find process mode in direct/global mode --- component/process/find_process_mode.go | 63 ++++++------ tunnel/tunnel.go | 131 +++++++++++++------------ 2 files changed, 97 insertions(+), 97 deletions(-) diff --git a/component/process/find_process_mode.go b/component/process/find_process_mode.go index 06618cef..8add9df8 100644 --- a/component/process/find_process_mode.go +++ b/component/process/find_process_mode.go @@ -1,57 +1,52 @@ package process import ( - "encoding/json" "errors" "strings" ) const ( - FindProcessAlways = "always" - FindProcessStrict = "strict" - FindProcessOff = "off" + FindProcessStrict FindProcessMode = iota + FindProcessAlways + FindProcessOff ) var ( - validModes = map[string]struct{}{ - FindProcessAlways: {}, - FindProcessOff: {}, - FindProcessStrict: {}, + validModes = map[string]FindProcessMode{ + FindProcessStrict.String(): FindProcessStrict, + FindProcessAlways.String(): FindProcessAlways, + FindProcessOff.String(): FindProcessOff, } ) -type FindProcessMode string +type FindProcessMode int -func (m FindProcessMode) Always() bool { - return m == FindProcessAlways -} - -func (m FindProcessMode) Off() bool { - return m == FindProcessOff -} - -func (m *FindProcessMode) UnmarshalYAML(unmarshal func(any) error) error { - var tp string - if err := unmarshal(&tp); err != nil { - return err - } - return m.Set(tp) -} - -func (m *FindProcessMode) UnmarshalJSON(data []byte) error { - var tp string - if err := json.Unmarshal(data, &tp); err != nil { - return err - } - return m.Set(tp) +// UnmarshalText unserialize FindProcessMode +func (m *FindProcessMode) UnmarshalText(data []byte) error { + return m.Set(string(data)) } func (m *FindProcessMode) Set(value string) error { - mode := strings.ToLower(value) - _, exist := validModes[mode] + mode, exist := validModes[strings.ToLower(value)] if !exist { return errors.New("invalid find process mode") } - *m = FindProcessMode(mode) + *m = mode return nil } + +// MarshalText serialize FindProcessMode +func (m FindProcessMode) MarshalText() ([]byte, error) { + return []byte(m.String()), nil +} + +func (m FindProcessMode) String() string { + switch m { + case FindProcessAlways: + return "always" + case FindProcessOff: + return "off" + default: + return "strict" + } +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 1314e540..f6d87670 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/metacubex/mihomo/common/atomic" N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/component/loopback" @@ -58,7 +59,7 @@ var ( // default timeout for UDP session udpTimeout = 60 * time.Second - findProcessMode P.FindProcessMode + findProcessMode = atomic.NewTypedValue(P.FindProcessStrict) fakeIPRange netip.Prefix @@ -273,13 +274,13 @@ func SetMode(m TunnelMode) { } func FindProcessMode() P.FindProcessMode { - return findProcessMode + return findProcessMode.Load() } // SetFindProcessMode replace SetAlwaysFindProcess // always find process info if legacyAlways = true or mode.Always() = true, may be increase many memory func SetFindProcessMode(mode P.FindProcessMode) { - findProcessMode = mode + findProcessMode.Store(mode) } func isHandle(t C.Type) bool { @@ -337,6 +338,68 @@ func resolveMetadata(metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err erro } return } + var ( + resolved bool + attemptProcessLookup = metadata.Type != C.INNER + ) + + if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok { + metadata.DstIP, _ = node.RandIP() + resolved = true + } + + helper := C.RuleMatchHelper{ + ResolveIP: func() { + if !resolved && metadata.Host != "" && !metadata.Resolved() { + ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) + defer cancel() + ip, err := resolver.ResolveIP(ctx, metadata.Host) + if err != nil { + log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) + } else { + log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String()) + metadata.DstIP = ip + } + resolved = true + } + }, + FindProcess: func() { + if attemptProcessLookup { + attemptProcessLookup = false + if !features.CMFA { + // normal check for process + uid, path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(metadata.SrcPort)) + if err != nil { + log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) + } else { + metadata.Process = filepath.Base(path) + metadata.ProcessPath = path + metadata.Uid = uid + + if pkg, err := P.FindPackageName(metadata); err == nil { // for android (not CMFA) package names + metadata.Process = pkg + } + } + } else { + // check package names + pkg, err := P.FindPackageName(metadata) + if err != nil { + log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) + } else { + metadata.Process = pkg + } + } + } + }, + } + + switch FindProcessMode() { + case P.FindProcessAlways: + helper.FindProcess() + helper.FindProcess = nil + case P.FindProcessOff: + helper.FindProcess = nil + } switch mode { case Direct: @@ -345,7 +408,7 @@ func resolveMetadata(metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err erro proxy = proxies["GLOBAL"] // Rule default: - proxy, rule, err = match(metadata) + proxy, rule, err = match(metadata, helper) } return } @@ -590,67 +653,9 @@ func logMetadata(metadata *C.Metadata, rule C.Rule, remoteConn C.Connection) { } } -func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { +func match(metadata *C.Metadata, helper C.RuleMatchHelper) (C.Proxy, C.Rule, error) { configMux.RLock() defer configMux.RUnlock() - var ( - resolved bool - attemptProcessLookup = metadata.Type != C.INNER - ) - - if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok { - metadata.DstIP, _ = node.RandIP() - resolved = true - } - - helper := C.RuleMatchHelper{ - ResolveIP: func() { - if !resolved && metadata.Host != "" && !metadata.Resolved() { - ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) - defer cancel() - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) - } else { - log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String()) - metadata.DstIP = ip - } - resolved = true - } - }, - FindProcess: func() { - if attemptProcessLookup && !findProcessMode.Off() { - attemptProcessLookup = false - if !features.CMFA { - // normal check for process - uid, path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(metadata.SrcPort)) - if err != nil { - log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) - } else { - metadata.Process = filepath.Base(path) - metadata.ProcessPath = path - metadata.Uid = uid - - if pkg, err := P.FindPackageName(metadata); err == nil { // for android (not CMFA) package names - metadata.Process = pkg - } - } - } else { - // check package names - pkg, err := P.FindPackageName(metadata) - if err != nil { - log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) - } else { - metadata.Process = pkg - } - } - } - }, - } - - if findProcessMode.Always() { - helper.FindProcess() - } for _, rule := range getRules(metadata) { if matched, ada := rule.Match(metadata, helper); matched {