From ae7967f66296b791a5eff0689b22c632c04c4942 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Tue, 10 Jun 2025 20:11:50 +0800 Subject: [PATCH] chore: the resolve and findProcess behaviors of Logic and SubRules follow the order and needs of the internal rules --- constant/provider/interface.go | 4 +- constant/rule.go | 9 ++-- rules/common/base.go | 8 ---- rules/common/domain.go | 2 +- rules/common/domain_keyword.go | 2 +- rules/common/domain_regex.go | 2 +- rules/common/domain_suffix.go | 2 +- rules/common/dscp.go | 2 +- rules/common/final.go | 2 +- rules/common/geoip.go | 10 ++--- rules/common/geosite.go | 2 +- rules/common/in_name.go | 2 +- rules/common/in_type.go | 2 +- rules/common/in_user.go | 2 +- rules/common/ipasn.go | 10 ++--- rules/common/ipcidr.go | 10 ++--- rules/common/ipsuffix.go | 10 ++--- rules/common/network_type.go | 2 +- rules/common/port.go | 2 +- rules/common/process.go | 9 ++-- rules/common/uid.go | 9 ++-- rules/logic/logic.go | 50 ++++----------------- rules/logic_test/logic_test.go | 8 ++-- rules/provider/classical_strategy.go | 34 +++------------ rules/provider/domain_strategy.go | 10 +---- rules/provider/ipcidr_strategy.go | 25 +++-------- rules/provider/provider.go | 16 ++----- rules/provider/rule_set.go | 41 +++++------------- tunnel/tunnel.go | 65 +++++++++++++++------------- 29 files changed, 121 insertions(+), 231 deletions(-) diff --git a/constant/provider/interface.go b/constant/provider/interface.go index 8c6ee6f8..b4e24b59 100644 --- a/constant/provider/interface.go +++ b/constant/provider/interface.go @@ -91,9 +91,7 @@ type RuleProvider interface { Provider Behavior() RuleBehavior Count() int - Match(*constant.Metadata) bool - ShouldResolveIP() bool - ShouldFindProcess() bool + Match(metadata *constant.Metadata, helper constant.RuleMatchHelper) bool Strategy() any } diff --git a/constant/rule.go b/constant/rule.go index 31702ddc..5bdb5a5e 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -111,14 +111,17 @@ func (rt RuleType) String() string { type Rule interface { RuleType() RuleType - Match(metadata *Metadata) (bool, string) + Match(metadata *Metadata, helper RuleMatchHelper) (bool, string) Adapter() string Payload() string - ShouldResolveIP() bool - ShouldFindProcess() bool ProviderNames() []string } +type RuleMatchHelper struct { + ResolveIP func() + FindProcess func() +} + type RuleGroup interface { Rule GetRecodeSize() int diff --git a/rules/common/base.go b/rules/common/base.go index 1abbe72c..0baff49d 100644 --- a/rules/common/base.go +++ b/rules/common/base.go @@ -21,14 +21,6 @@ var ( type Base struct { } -func (b *Base) ShouldFindProcess() bool { - return false -} - -func (b *Base) ShouldResolveIP() bool { - return false -} - func (b *Base) ProviderNames() []string { return nil } func ParseParams(params []string) (isSrc bool, noResolve bool) { diff --git a/rules/common/domain.go b/rules/common/domain.go index 306eb65f..7feb6a19 100644 --- a/rules/common/domain.go +++ b/rules/common/domain.go @@ -17,7 +17,7 @@ func (d *Domain) RuleType() C.RuleType { return C.Domain } -func (d *Domain) Match(metadata *C.Metadata) (bool, string) { +func (d *Domain) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { return metadata.RuleHost() == d.domain, d.adapter } diff --git a/rules/common/domain_keyword.go b/rules/common/domain_keyword.go index 9d6f1c15..813e07a0 100644 --- a/rules/common/domain_keyword.go +++ b/rules/common/domain_keyword.go @@ -17,7 +17,7 @@ func (dk *DomainKeyword) RuleType() C.RuleType { return C.DomainKeyword } -func (dk *DomainKeyword) Match(metadata *C.Metadata) (bool, string) { +func (dk *DomainKeyword) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { domain := metadata.RuleHost() return strings.Contains(domain, dk.keyword), dk.adapter } diff --git a/rules/common/domain_regex.go b/rules/common/domain_regex.go index d214a772..62957b4b 100644 --- a/rules/common/domain_regex.go +++ b/rules/common/domain_regex.go @@ -16,7 +16,7 @@ func (dr *DomainRegex) RuleType() C.RuleType { return C.DomainRegex } -func (dr *DomainRegex) Match(metadata *C.Metadata) (bool, string) { +func (dr *DomainRegex) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { domain := metadata.RuleHost() match, _ := dr.regex.MatchString(domain) return match, dr.adapter diff --git a/rules/common/domain_suffix.go b/rules/common/domain_suffix.go index c5b87208..54c32014 100644 --- a/rules/common/domain_suffix.go +++ b/rules/common/domain_suffix.go @@ -17,7 +17,7 @@ func (ds *DomainSuffix) RuleType() C.RuleType { return C.DomainSuffix } -func (ds *DomainSuffix) Match(metadata *C.Metadata) (bool, string) { +func (ds *DomainSuffix) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { domain := metadata.RuleHost() return strings.HasSuffix(domain, "."+ds.suffix) || domain == ds.suffix, ds.adapter } diff --git a/rules/common/dscp.go b/rules/common/dscp.go index c839b20d..f3dcca8a 100644 --- a/rules/common/dscp.go +++ b/rules/common/dscp.go @@ -18,7 +18,7 @@ func (d *DSCP) RuleType() C.RuleType { return C.DSCP } -func (d *DSCP) Match(metadata *C.Metadata) (bool, string) { +func (d *DSCP) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { return d.ranges.Check(metadata.DSCP), d.adapter } diff --git a/rules/common/final.go b/rules/common/final.go index d3a415a0..92c46c7b 100644 --- a/rules/common/final.go +++ b/rules/common/final.go @@ -13,7 +13,7 @@ func (f *Match) RuleType() C.RuleType { return C.MATCH } -func (f *Match) Match(metadata *C.Metadata) (bool, string) { +func (f *Match) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { return true, f.adapter } diff --git a/rules/common/geoip.go b/rules/common/geoip.go index 61fae504..003ab90a 100644 --- a/rules/common/geoip.go +++ b/rules/common/geoip.go @@ -33,7 +33,11 @@ func (g *GEOIP) RuleType() C.RuleType { return C.GEOIP } -func (g *GEOIP) Match(metadata *C.Metadata) (bool, string) { +func (g *GEOIP) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { + if !g.noResolveIP && !g.isSourceIP && helper.ResolveIP != nil { + helper.ResolveIP() + } + ip := metadata.DstIP if g.isSourceIP { ip = metadata.SrcIP @@ -161,10 +165,6 @@ func (g *GEOIP) Payload() string { return g.country } -func (g *GEOIP) ShouldResolveIP() bool { - return !g.noResolveIP -} - func (g *GEOIP) GetCountry() string { return g.country } diff --git a/rules/common/geosite.go b/rules/common/geosite.go index 851bc8a4..f61d68a8 100644 --- a/rules/common/geosite.go +++ b/rules/common/geosite.go @@ -22,7 +22,7 @@ func (gs *GEOSITE) RuleType() C.RuleType { return C.GEOSITE } -func (gs *GEOSITE) Match(metadata *C.Metadata) (bool, string) { +func (gs *GEOSITE) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { return gs.MatchDomain(metadata.RuleHost()), gs.adapter } diff --git a/rules/common/in_name.go b/rules/common/in_name.go index 9b14ef6a..9986d8a6 100644 --- a/rules/common/in_name.go +++ b/rules/common/in_name.go @@ -13,7 +13,7 @@ type InName struct { payload string } -func (u *InName) Match(metadata *C.Metadata) (bool, string) { +func (u *InName) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { for _, name := range u.names { if metadata.InName == name { return true, u.adapter diff --git a/rules/common/in_type.go b/rules/common/in_type.go index fc73b208..0f11468f 100644 --- a/rules/common/in_type.go +++ b/rules/common/in_type.go @@ -13,7 +13,7 @@ type InType struct { payload string } -func (u *InType) Match(metadata *C.Metadata) (bool, string) { +func (u *InType) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { for _, tp := range u.types { if metadata.Type == tp { return true, u.adapter diff --git a/rules/common/in_user.go b/rules/common/in_user.go index ebe881af..6b786433 100644 --- a/rules/common/in_user.go +++ b/rules/common/in_user.go @@ -13,7 +13,7 @@ type InUser struct { payload string } -func (u *InUser) Match(metadata *C.Metadata) (bool, string) { +func (u *InUser) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { for _, user := range u.users { if metadata.InUser == user { return true, u.adapter diff --git a/rules/common/ipasn.go b/rules/common/ipasn.go index 7d554103..6adefd05 100644 --- a/rules/common/ipasn.go +++ b/rules/common/ipasn.go @@ -15,7 +15,11 @@ type ASN struct { isSourceIP bool } -func (a *ASN) Match(metadata *C.Metadata) (bool, string) { +func (a *ASN) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { + if !a.noResolveIP && !a.isSourceIP && helper.ResolveIP != nil { + helper.ResolveIP() + } + ip := metadata.DstIP if a.isSourceIP { ip = metadata.SrcIP @@ -49,10 +53,6 @@ func (a *ASN) Payload() string { return a.asn } -func (a *ASN) ShouldResolveIP() bool { - return !a.noResolveIP -} - func (a *ASN) GetASN() string { return a.asn } diff --git a/rules/common/ipcidr.go b/rules/common/ipcidr.go index 9c159502..88d6ccc5 100644 --- a/rules/common/ipcidr.go +++ b/rules/common/ipcidr.go @@ -35,7 +35,11 @@ func (i *IPCIDR) RuleType() C.RuleType { return C.IPCIDR } -func (i *IPCIDR) Match(metadata *C.Metadata) (bool, string) { +func (i *IPCIDR) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { + if !i.noResolveIP && !i.isSourceIP && helper.ResolveIP != nil { + helper.ResolveIP() + } + ip := metadata.DstIP if i.isSourceIP { ip = metadata.SrcIP @@ -51,10 +55,6 @@ func (i *IPCIDR) Payload() string { return i.ipnet.String() } -func (i *IPCIDR) ShouldResolveIP() bool { - return !i.noResolveIP -} - func NewIPCIDR(s string, adapter string, opts ...IPCIDROption) (*IPCIDR, error) { ipnet, err := netip.ParsePrefix(s) if err != nil { diff --git a/rules/common/ipsuffix.go b/rules/common/ipsuffix.go index 3251faf8..98f07acd 100644 --- a/rules/common/ipsuffix.go +++ b/rules/common/ipsuffix.go @@ -22,7 +22,11 @@ func (is *IPSuffix) RuleType() C.RuleType { return C.IPSuffix } -func (is *IPSuffix) Match(metadata *C.Metadata) (bool, string) { +func (is *IPSuffix) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { + if !is.noResolveIP && !is.isSourceIP && helper.ResolveIP != nil { + helper.ResolveIP() + } + ip := metadata.DstIP if is.isSourceIP { ip = metadata.SrcIP @@ -57,10 +61,6 @@ func (is *IPSuffix) Payload() string { return is.payload } -func (is *IPSuffix) ShouldResolveIP() bool { - return !is.noResolveIP -} - func NewIPSuffix(payload, adapter string, isSrc, noResolveIP bool) (*IPSuffix, error) { ipnet, err := netip.ParsePrefix(payload) if err != nil { diff --git a/rules/common/network_type.go b/rules/common/network_type.go index 83a332d8..d7003cd8 100644 --- a/rules/common/network_type.go +++ b/rules/common/network_type.go @@ -34,7 +34,7 @@ func (n *NetworkType) RuleType() C.RuleType { return C.Network } -func (n *NetworkType) Match(metadata *C.Metadata) (bool, string) { +func (n *NetworkType) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { return n.network == metadata.NetWork, n.adapter } diff --git a/rules/common/port.go b/rules/common/port.go index d3f6e1b4..3e9d8806 100644 --- a/rules/common/port.go +++ b/rules/common/port.go @@ -19,7 +19,7 @@ func (p *Port) RuleType() C.RuleType { return p.ruleType } -func (p *Port) Match(metadata *C.Metadata) (bool, string) { +func (p *Port) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { targetPort := metadata.DstPort switch p.ruleType { case C.InPort: diff --git a/rules/common/process.go b/rules/common/process.go index 8932e946..c7e69574 100644 --- a/rules/common/process.go +++ b/rules/common/process.go @@ -30,7 +30,10 @@ func (ps *Process) RuleType() C.RuleType { return C.ProcessPath } -func (ps *Process) Match(metadata *C.Metadata) (bool, string) { +func (ps *Process) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { + if helper.FindProcess != nil { + helper.FindProcess() + } if ps.nameOnly { if ps.regexp != nil { match, _ := ps.regexp.MatchString(metadata.Process) @@ -54,10 +57,6 @@ func (ps *Process) Payload() string { return ps.process } -func (ps *Process) ShouldFindProcess() bool { - return true -} - func NewProcess(process string, adapter string, nameOnly bool, regex bool) (*Process, error) { var r *regexp2.Regexp var err error diff --git a/rules/common/uid.go b/rules/common/uid.go index c80632b0..31365c9b 100644 --- a/rules/common/uid.go +++ b/rules/common/uid.go @@ -41,7 +41,10 @@ func (u *Uid) RuleType() C.RuleType { return C.Uid } -func (u *Uid) Match(metadata *C.Metadata) (bool, string) { +func (u *Uid) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { + if helper.FindProcess != nil { + helper.FindProcess() + } if metadata.Uid != 0 { if u.uids.Check(metadata.Uid) { return true, u.adapter @@ -58,7 +61,3 @@ func (u *Uid) Adapter() string { func (u *Uid) Payload() string { return u.oUid } - -func (u *Uid) ShouldFindProcess() bool { - return true -} diff --git a/rules/logic/logic.go b/rules/logic/logic.go index f7b5a987..6a3be8d4 100644 --- a/rules/logic/logic.go +++ b/rules/logic/logic.go @@ -195,11 +195,11 @@ func (logic *Logic) RuleType() C.RuleType { return logic.ruleType } -func matchSubRules(metadata *C.Metadata, name string, subRules map[string][]C.Rule) (bool, string) { +func matchSubRules(metadata *C.Metadata, name string, subRules map[string][]C.Rule, helper C.RuleMatchHelper) (bool, string) { for _, rule := range subRules[name] { - if m, a := rule.Match(metadata); m { + if m, a := rule.Match(metadata, helper); m { if rule.RuleType() == C.SubRules { - return matchSubRules(metadata, rule.Adapter(), subRules) + return matchSubRules(metadata, rule.Adapter(), subRules, helper) } else { return m, a } @@ -208,28 +208,28 @@ func matchSubRules(metadata *C.Metadata, name string, subRules map[string][]C.Ru return false, "" } -func (logic *Logic) Match(metadata *C.Metadata) (bool, string) { +func (logic *Logic) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { switch logic.ruleType { case C.SubRules: - if m, _ := logic.rules[0].Match(metadata); m { - return matchSubRules(metadata, logic.adapter, logic.subRules) + if m, _ := logic.rules[0].Match(metadata, helper); m { + return matchSubRules(metadata, logic.adapter, logic.subRules, helper) } return false, "" case C.NOT: - if m, _ := logic.rules[0].Match(metadata); !m { + if m, _ := logic.rules[0].Match(metadata, helper); !m { return true, logic.adapter } return false, "" case C.OR: for _, rule := range logic.rules { - if m, _ := rule.Match(metadata); m { + if m, _ := rule.Match(metadata, helper); m { return true, logic.adapter } } return false, "" case C.AND: for _, rule := range logic.rules { - if m, _ := rule.Match(metadata); !m { + if m, _ := rule.Match(metadata, helper); !m { return false, logic.adapter } } @@ -266,38 +266,6 @@ func (logic *Logic) Payload() string { return logic.payload } -func (logic *Logic) ShouldResolveIP() bool { - if logic.ruleType == C.SubRules { - for _, rule := range logic.subRules[logic.adapter] { - if rule.ShouldResolveIP() { - return true - } - } - } - for _, rule := range logic.rules { - if rule.ShouldResolveIP() { - return true - } - } - return false -} - -func (logic *Logic) ShouldFindProcess() bool { - if logic.ruleType == C.SubRules { - for _, rule := range logic.subRules[logic.adapter] { - if rule.ShouldFindProcess() { - return true - } - } - } - for _, rule := range logic.rules { - if rule.ShouldFindProcess() { - return true - } - } - return false -} - func (logic *Logic) ProviderNames() (names []string) { for _, rule := range logic.rules { names = append(names, rule.ProviderNames()...) diff --git a/rules/logic_test/logic_test.go b/rules/logic_test/logic_test.go index e88c8578..27a21916 100644 --- a/rules/logic_test/logic_test.go +++ b/rules/logic_test/logic_test.go @@ -16,12 +16,11 @@ func TestAND(t *testing.T) { and, err := NewAND("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT", ParseRule) assert.Equal(t, nil, err) assert.Equal(t, "DIRECT", and.Adapter()) - assert.Equal(t, false, and.ShouldResolveIP()) m, _ := and.Match(&C.Metadata{ Host: "baidu.com", NetWork: C.TCP, DstPort: 20000, - }) + }, C.RuleMatchHelper{}) assert.Equal(t, true, m) and, err = NewAND("(DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT", ParseRule) @@ -36,7 +35,7 @@ func TestNOT(t *testing.T) { assert.Equal(t, nil, err) m, _ := not.Match(&C.Metadata{ DstPort: 6100, - }) + }, C.RuleMatchHelper{}) assert.Equal(t, false, m) _, err = NewNOT("((DST-PORT,5600-6666),(DOMAIN,baidu.com))", "DIRECT", ParseRule) @@ -51,7 +50,6 @@ func TestOR(t *testing.T) { assert.Equal(t, nil, err) m, _ := or.Match(&C.Metadata{ NetWork: C.TCP, - }) + }, C.RuleMatchHelper{}) assert.Equal(t, true, m) - assert.Equal(t, false, or.ShouldResolveIP()) } diff --git a/rules/provider/classical_strategy.go b/rules/provider/classical_strategy.go index 95a772e8..3b74f5cd 100644 --- a/rules/provider/classical_strategy.go +++ b/rules/provider/classical_strategy.go @@ -10,20 +10,18 @@ import ( ) type classicalStrategy struct { - rules []C.Rule - count int - shouldResolveIP bool - shouldFindProcess bool - parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) + rules []C.Rule + count int + parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) } func (c *classicalStrategy) Behavior() P.RuleBehavior { return P.Classical } -func (c *classicalStrategy) Match(metadata *C.Metadata) bool { +func (c *classicalStrategy) Match(metadata *C.Metadata, helper C.RuleMatchHelper) bool { for _, rule := range c.rules { - if m, _ := rule.Match(metadata); m { + if m, _ := rule.Match(metadata, helper); m { return true } } @@ -35,39 +33,17 @@ func (c *classicalStrategy) Count() int { return c.count } -func (c *classicalStrategy) ShouldResolveIP() bool { - return c.shouldResolveIP -} - -func (c *classicalStrategy) ShouldFindProcess() bool { - return c.shouldFindProcess -} - func (c *classicalStrategy) Reset() { c.rules = nil c.count = 0 - c.shouldFindProcess = false - c.shouldResolveIP = false } func (c *classicalStrategy) Insert(rule string) { ruleType, rule, params := ruleParse(rule) - - if ruleType == "PROCESS-NAME" { - c.shouldFindProcess = true - } - r, err := c.parse(ruleType, rule, "", params) if err != nil { log.Warnln("parse classical rule error: %s", err.Error()) } else { - if r.ShouldResolveIP() { - c.shouldResolveIP = true - } - if r.ShouldFindProcess() { - c.shouldFindProcess = true - } - c.rules = append(c.rules, r) c.count++ } diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index b893f038..2e195027 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -23,11 +23,7 @@ func (d *domainStrategy) Behavior() P.RuleBehavior { return P.Domain } -func (d *domainStrategy) ShouldFindProcess() bool { - return false -} - -func (d *domainStrategy) Match(metadata *C.Metadata) bool { +func (d *domainStrategy) Match(metadata *C.Metadata, helper C.RuleMatchHelper) bool { return d.domainSet != nil && d.domainSet.Has(metadata.RuleHost()) } @@ -35,10 +31,6 @@ func (d *domainStrategy) Count() int { return d.count } -func (d *domainStrategy) ShouldResolveIP() bool { - return false -} - func (d *domainStrategy) Reset() { d.domainTrie = trie.New[struct{}]() d.domainSet = nil diff --git a/rules/provider/ipcidr_strategy.go b/rules/provider/ipcidr_strategy.go index 9efffed9..78a56cbb 100644 --- a/rules/provider/ipcidr_strategy.go +++ b/rules/provider/ipcidr_strategy.go @@ -14,21 +14,19 @@ import ( ) type ipcidrStrategy struct { - count int - shouldResolveIP bool - cidrSet *cidr.IpCidrSet - //trie *trie.IpCidrTrie + count int + cidrSet *cidr.IpCidrSet + //trie *trie.IpCidrTrie } func (i *ipcidrStrategy) Behavior() P.RuleBehavior { return P.IPCIDR } -func (i *ipcidrStrategy) ShouldFindProcess() bool { - return false -} - -func (i *ipcidrStrategy) Match(metadata *C.Metadata) bool { +func (i *ipcidrStrategy) Match(metadata *C.Metadata, helper C.RuleMatchHelper) bool { + if helper.ResolveIP != nil { + helper.ResolveIP() + } // return i.trie != nil && i.trie.IsContain(metadata.DstIP.AsSlice()) return i.cidrSet != nil && i.cidrSet.IsContain(metadata.DstIP) } @@ -37,15 +35,10 @@ func (i *ipcidrStrategy) Count() int { return i.count } -func (i *ipcidrStrategy) ShouldResolveIP() bool { - return i.shouldResolveIP -} - func (i *ipcidrStrategy) Reset() { // i.trie = trie.NewIpCidrTrie() i.cidrSet = cidr.NewIpCidrSet() i.count = 0 - i.shouldResolveIP = false } func (i *ipcidrStrategy) Insert(rule string) { @@ -54,7 +47,6 @@ func (i *ipcidrStrategy) Insert(rule string) { if err != nil { log.Warnln("invalid Ipcidr:[%s]", rule) } else { - i.shouldResolveIP = true i.count++ } } @@ -70,9 +62,6 @@ func (i *ipcidrStrategy) FromMrs(r io.Reader, count int) error { } i.count = count i.cidrSet = cidrSet - if i.count > 0 { - i.shouldResolveIP = true - } return nil } diff --git a/rules/provider/provider.go b/rules/provider/provider.go index 23a5781f..f197a80a 100644 --- a/rules/provider/provider.go +++ b/rules/provider/provider.go @@ -46,10 +46,8 @@ type providerForApi struct { type ruleStrategy interface { Behavior() P.RuleBehavior - Match(metadata *C.Metadata) bool + Match(metadata *C.Metadata, helper C.RuleMatchHelper) bool Count() int - ShouldResolveIP() bool - ShouldFindProcess() bool Reset() Insert(rule string) FinishInsert() @@ -79,16 +77,8 @@ func (bp *baseProvider) Count() int { return bp.strategy.Count() } -func (bp *baseProvider) Match(metadata *C.Metadata) bool { - return bp.strategy != nil && bp.strategy.Match(metadata) -} - -func (bp *baseProvider) ShouldResolveIP() bool { - return bp.strategy.ShouldResolveIP() -} - -func (bp *baseProvider) ShouldFindProcess() bool { - return bp.strategy.ShouldFindProcess() +func (bp *baseProvider) Match(metadata *C.Metadata, helper C.RuleMatchHelper) bool { + return bp.strategy != nil && bp.strategy.Match(metadata, helper) } func (bp *baseProvider) Strategy() any { diff --git a/rules/provider/rule_set.go b/rules/provider/rule_set.go index 2ad0bd3d..320307e7 100644 --- a/rules/provider/rule_set.go +++ b/rules/provider/rule_set.go @@ -10,47 +10,40 @@ import ( type RuleSet struct { *common.Base - ruleProviderName string - adapter string - isSrc bool - noResolveIP bool - shouldFindProcess bool -} - -func (rs *RuleSet) ShouldFindProcess() bool { - if rs.shouldFindProcess { - return true - } - if provider, ok := rs.getProvider(); ok { - return provider.ShouldFindProcess() - } - return false + ruleProviderName string + adapter string + isSrc bool + noResolveIP bool } func (rs *RuleSet) RuleType() C.RuleType { return C.RuleSet } -func (rs *RuleSet) Match(metadata *C.Metadata) (bool, string) { +func (rs *RuleSet) Match(metadata *C.Metadata, helper C.RuleMatchHelper) (bool, string) { if provider, ok := rs.getProvider(); ok { if rs.isSrc { metadata.SwapSrcDst() defer metadata.SwapSrcDst() + + helper.ResolveIP = nil // src mode should not resolve ip + } else if rs.noResolveIP { + helper.ResolveIP = nil } - return provider.Match(metadata), rs.adapter + return provider.Match(metadata, helper), rs.adapter } return false, "" } // MatchDomain implements C.DomainMatcher func (rs *RuleSet) MatchDomain(domain string) bool { - ok, _ := rs.Match(&C.Metadata{Host: domain}) + ok, _ := rs.Match(&C.Metadata{Host: domain}, C.RuleMatchHelper{}) return ok } // MatchIp implements C.IpMatcher func (rs *RuleSet) MatchIp(ip netip.Addr) bool { - ok, _ := rs.Match(&C.Metadata{DstIP: ip}) + ok, _ := rs.Match(&C.Metadata{DstIP: ip}, C.RuleMatchHelper{}) return ok } @@ -62,16 +55,6 @@ func (rs *RuleSet) Payload() string { return rs.ruleProviderName } -func (rs *RuleSet) ShouldResolveIP() bool { - if rs.noResolveIP { - return false - } - if provider, ok := rs.getProvider(); ok { - return provider.ShouldResolveIP() - } - return false -} - func (rs *RuleSet) ProviderNames() []string { return []string{rs.ruleProviderName} } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 744a01b1..1314e540 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -590,10 +590,6 @@ func logMetadata(metadata *C.Metadata, rule C.Rule, remoteConn C.Connection) { } } -func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool { - return rule.ShouldResolveIP() && metadata.Host != "" && !metadata.DstIP.IsValid() -} - func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { configMux.RLock() defer configMux.RUnlock() @@ -607,9 +603,9 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { resolved = true } - for _, rule := range getRules(metadata) { - if !resolved && shouldResolveIP(rule, metadata) { - func() { + helper := C.RuleMatchHelper{ + ResolveIP: func() { + if !resolved && metadata.Host != "" && !metadata.Resolved() { ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) defer cancel() ip, err := resolver.ResolveIP(ctx, metadata.Host) @@ -620,37 +616,44 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { metadata.DstIP = ip } resolved = true - }() - } + } + }, + FindProcess: func() { + if attemptProcessLookup && !findProcessMode.Off() { + attemptProcessLookup = false + if !features.CMFA { + // normal check for process + uid, path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(metadata.SrcPort)) + if err != nil { + log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) + } else { + metadata.Process = filepath.Base(path) + metadata.ProcessPath = path + metadata.Uid = uid - if attemptProcessLookup && !findProcessMode.Off() && (findProcessMode.Always() || rule.ShouldFindProcess()) { - attemptProcessLookup = false - if !features.CMFA { - // normal check for process - uid, path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(metadata.SrcPort)) - if err != nil { - log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) + if pkg, err := P.FindPackageName(metadata); err == nil { // for android (not CMFA) package names + metadata.Process = pkg + } + } } else { - metadata.Process = filepath.Base(path) - metadata.ProcessPath = path - metadata.Uid = uid - - if pkg, err := P.FindPackageName(metadata); err == nil { // for android (not CMFA) package names + // check package names + pkg, err := P.FindPackageName(metadata) + if err != nil { + log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) + } else { metadata.Process = pkg } } - } else { - // check package names - pkg, err := P.FindPackageName(metadata) - if err != nil { - log.Debugln("[Process] find process error for %s: %v", metadata.String(), err) - } else { - metadata.Process = pkg - } } - } + }, + } - if matched, ada := rule.Match(metadata); matched { + if findProcessMode.Always() { + helper.FindProcess() + } + + for _, rule := range getRules(metadata) { + if matched, ada := rule.Match(metadata, helper); matched { adapter, ok := proxies[ada] if !ok { continue