chore: better IPSet code

This commit is contained in:
wwqgtxx 2024-01-13 18:15:30 +08:00
parent e860497c0c
commit edf318bae0
3 changed files with 33 additions and 19 deletions

View File

@ -1,12 +1,16 @@
package cidr package cidr
import ( import (
"go4.org/netipx" "fmt"
"net/netip" "net/netip"
"unsafe"
"go4.org/netipx"
) )
type IpCidrSet struct { type IpCidrSet struct {
Ranges *netipx.IPSet // must same with netipx.IPSet
rr []netipx.IPRange
} }
func NewIpCidrSet() *IpCidrSet { func NewIpCidrSet() *IpCidrSet {
@ -18,15 +22,15 @@ func (set *IpCidrSet) AddIpCidrForString(ipCidr string) error {
if err != nil { if err != nil {
return err return err
} }
err = set.AddIpCidr(prefix) return set.AddIpCidr(prefix)
return nil
} }
func (set *IpCidrSet) AddIpCidr(ipCidr netip.Prefix) (err error) { func (set *IpCidrSet) AddIpCidr(ipCidr netip.Prefix) (err error) {
var b netipx.IPSetBuilder if r := netipx.RangeOfPrefix(ipCidr); r.IsValid() {
b.AddSet(set.Ranges) set.rr = append(set.rr, r)
b.AddPrefix(ipCidr) } else {
set.Ranges, err = b.IPSet() err = fmt.Errorf("not valid ipcidr range: %s", ipCidr)
}
return return
} }
@ -39,10 +43,24 @@ func (set *IpCidrSet) IsContainForString(ipString string) bool {
} }
func (set *IpCidrSet) IsContain(ip netip.Addr) bool { func (set *IpCidrSet) IsContain(ip netip.Addr) bool {
if set.Ranges == nil { return set.toIPSet().Contains(ip.WithZone(""))
return false
}
return set.Ranges.Contains(ip.WithZone(""))
} }
func (set *IpCidrSet) Merge() {} func (set *IpCidrSet) Merge() error {
var b netipx.IPSetBuilder
b.AddSet(set.toIPSet())
i, err := b.IPSet()
if err != nil {
return err
}
set.fromIPSet(i)
return nil
}
func (set *IpCidrSet) toIPSet() *netipx.IPSet {
return (*netipx.IPSet)(unsafe.Pointer(set))
}
func (set *IpCidrSet) fromIPSet(i *netipx.IPSet) {
*set = *(*IpCidrSet)(unsafe.Pointer(i))
}

View File

@ -1,9 +1,7 @@
package cidr package cidr
import ( import (
"go4.org/netipx"
"testing" "testing"
"unsafe"
) )
func TestIpv4(t *testing.T) { func TestIpv4(t *testing.T) {
@ -99,7 +97,7 @@ func TestMerge(t *testing.T) {
set.AddIpCidrForString(test.ipCidr2) set.AddIpCidrForString(test.ipCidr2)
set.Merge() set.Merge()
rangesLen := len(*(*[]netipx.IPRange)(unsafe.Pointer(set.Ranges))) rangesLen := len(set.rr)
if rangesLen != test.expectedLen { if rangesLen != test.expectedLen {
t.Errorf("Expected len: %v, got: %v", test.expectedLen, rangesLen) t.Errorf("Expected len: %v, got: %v", test.expectedLen, rangesLen)

View File

@ -137,9 +137,7 @@ func (m *GeoIPMatcher) Init(cidrs []*CIDR) error {
return fmt.Errorf("error when loading GeoIP: %w", err) return fmt.Errorf("error when loading GeoIP: %w", err)
} }
} }
m.cidrSet.Merge() return m.cidrSet.Merge()
return nil
} }
func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) { func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) {