diff --git a/adapter/adapter.go b/adapter/adapter.go index 53f7c6d9..8a9143dc 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -14,10 +14,10 @@ import ( "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/queue" "github.com/metacubex/mihomo/common/utils" + "github.com/metacubex/mihomo/common/xsync" "github.com/metacubex/mihomo/component/ca" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" - "github.com/puzpuzpuz/xsync/v3" ) var UnifiedDelay = atomic.NewBool(false) @@ -35,7 +35,7 @@ type Proxy struct { C.ProxyAdapter alive atomic.Bool history *queue.Queue[C.DelayHistory] - extra *xsync.MapOf[string, *internalProxyState] + extra *xsync.Map[string, *internalProxyState] } // Adapter implements C.Proxy @@ -293,7 +293,7 @@ func NewProxy(adapter C.ProxyAdapter) *Proxy { ProxyAdapter: adapter, history: queue.New[C.DelayHistory](defaultHistoriesNum), alive: atomic.NewBool(true), - extra: xsync.NewMapOf[string, *internalProxyState]()} + extra: xsync.NewMap[string, *internalProxyState]()} } func urlToMetadata(rawURL string) (addr C.Metadata, err error) { diff --git a/common/maphash/common.go b/common/maphash/common.go new file mode 100644 index 00000000..14b4aaa9 --- /dev/null +++ b/common/maphash/common.go @@ -0,0 +1,19 @@ +package maphash + +import "hash/maphash" + +type Seed = maphash.Seed + +func MakeSeed() Seed { + return maphash.MakeSeed() +} + +type Hash = maphash.Hash + +func Bytes(seed Seed, b []byte) uint64 { + return maphash.Bytes(seed, b) +} + +func String(seed Seed, s string) uint64 { + return maphash.String(seed, s) +} diff --git a/common/maphash/comparable_go120.go b/common/maphash/comparable_go120.go new file mode 100644 index 00000000..05256d67 --- /dev/null +++ b/common/maphash/comparable_go120.go @@ -0,0 +1,140 @@ +//go:build !go1.24 + +package maphash + +import "unsafe" + +func Comparable[T comparable](s Seed, v T) uint64 { + return comparableHash(*(*seedTyp)(unsafe.Pointer(&s)), v) +} + +func comparableHash[T comparable](seed seedTyp, v T) uint64 { + s := seed.s + var m map[T]struct{} + mTyp := iTypeOf(m) + var hasher func(unsafe.Pointer, uintptr) uintptr + hasher = (*iMapType)(unsafe.Pointer(mTyp)).Hasher + + p := escape(unsafe.Pointer(&v)) + + if ptrSize == 8 { + return uint64(hasher(p, uintptr(s))) + } + lo := hasher(p, uintptr(s)) + hi := hasher(p, uintptr(s>>32)) + return uint64(hi)<<32 | uint64(lo) +} + +// WriteComparable adds x to the data hashed by h. +func WriteComparable[T comparable](h *Hash, x T) { + // writeComparable (not in purego mode) directly operates on h.state + // without using h.buf. Mix in the buffer length so it won't + // commute with a buffered write, which either changes h.n or changes + // h.state. + hash := (*hashTyp)(unsafe.Pointer(h)) + if hash.n != 0 { + hash.state.s = comparableHash(hash.state, hash.n) + } + hash.state.s = comparableHash(hash.state, x) +} + +// go/src/hash/maphash/maphash.go +type hashTyp struct { + _ [0]func() // not comparable + seed seedTyp // initial seed used for this hash + state seedTyp // current hash of all flushed bytes + buf [128]byte // unflushed byte buffer + n int // number of unflushed bytes +} + +type seedTyp struct { + s uint64 +} + +type iTFlag uint8 +type iKind uint8 +type iNameOff int32 + +// TypeOff is the offset to a type from moduledata.types. See resolveTypeOff in runtime. 
+type iTypeOff int32 + +type iType struct { + Size_ uintptr + PtrBytes uintptr // number of (prefix) bytes in the type that can contain pointers + Hash uint32 // hash of type; avoids computation in hash tables + TFlag iTFlag // extra type information flags + Align_ uint8 // alignment of variable with this type + FieldAlign_ uint8 // alignment of struct field with this type + Kind_ iKind // enumeration for C + // function for comparing objects of this type + // (ptr to object A, ptr to object B) -> ==? + Equal func(unsafe.Pointer, unsafe.Pointer) bool + // GCData stores the GC type data for the garbage collector. + // Normally, GCData points to a bitmask that describes the + // ptr/nonptr fields of the type. The bitmask will have at + // least PtrBytes/ptrSize bits. + // If the TFlagGCMaskOnDemand bit is set, GCData is instead a + // **byte and the pointer to the bitmask is one dereference away. + // The runtime will build the bitmask if needed. + // (See runtime/type.go:getGCMask.) + // Note: multiple types may have the same value of GCData, + // including when TFlagGCMaskOnDemand is set. The types will, of course, + // have the same pointer layout (but not necessarily the same size). + GCData *byte + Str iNameOff // string form + PtrToThis iTypeOff // type for pointer to this type, may be zero +} + +type iMapType struct { + iType + Key *iType + Elem *iType + Group *iType // internal type representing a slot group + // function for hashing keys (ptr to key, seed) -> hash + Hasher func(unsafe.Pointer, uintptr) uintptr +} + +func iTypeOf(a any) *iType { + eface := *(*iEmptyInterface)(unsafe.Pointer(&a)) + // Types are either static (for compiler-created types) or + // heap-allocated but always reachable (for reflection-created + // types, held in the central map). So there is no need to + // escape types. noescape here help avoid unnecessary escape + // of v. + return (*iType)(noescape(unsafe.Pointer(eface.Type))) +} + +type iEmptyInterface struct { + Type *iType + Data unsafe.Pointer +} + +// noescape hides a pointer from escape analysis. noescape is +// the identity function but escape analysis doesn't think the +// output depends on the input. noescape is inlined and currently +// compiles down to zero instructions. +// USE CAREFULLY! +// +// nolint:all +// +//go:nosplit +//goland:noinspection ALL +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +var alwaysFalse bool +var escapeSink any + +// escape forces any pointers in x to escape to the heap. +func escape[T any](x T) T { + if alwaysFalse { + escapeSink = x + } + return x +} + +// ptrSize is the size of a pointer in bytes - unsafe.Sizeof(uintptr(0)) but as an ideal constant. +// It is also the size of the machine's native word size (that is, 4 on 32-bit systems, 8 on 64-bit). 
+const ptrSize = 4 << (^uintptr(0) >> 63) diff --git a/common/maphash/comparable_go124.go b/common/maphash/comparable_go124.go new file mode 100644 index 00000000..3a96edb6 --- /dev/null +++ b/common/maphash/comparable_go124.go @@ -0,0 +1,13 @@ +//go:build go1.24 + +package maphash + +import "hash/maphash" + +func Comparable[T comparable](seed Seed, v T) uint64 { + return maphash.Comparable(seed, v) +} + +func WriteComparable[T comparable](h *Hash, x T) { + maphash.WriteComparable(h, x) +} diff --git a/common/maphash/maphash_test.go b/common/maphash/maphash_test.go new file mode 100644 index 00000000..73887f27 --- /dev/null +++ b/common/maphash/maphash_test.go @@ -0,0 +1,532 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package maphash + +import ( + "bytes" + "fmt" + "hash" + "math" + "reflect" + "strings" + "testing" + "unsafe" + + rand "github.com/metacubex/randv2" +) + +func TestUnseededHash(t *testing.T) { + m := map[uint64]struct{}{} + for i := 0; i < 1000; i++ { + h := new(Hash) + m[h.Sum64()] = struct{}{} + } + if len(m) < 900 { + t.Errorf("empty hash not sufficiently random: got %d, want 1000", len(m)) + } +} + +func TestSeededHash(t *testing.T) { + s := MakeSeed() + m := map[uint64]struct{}{} + for i := 0; i < 1000; i++ { + h := new(Hash) + h.SetSeed(s) + m[h.Sum64()] = struct{}{} + } + if len(m) != 1 { + t.Errorf("seeded hash is random: got %d, want 1", len(m)) + } +} + +func TestHashGrouping(t *testing.T) { + b := bytes.Repeat([]byte("foo"), 100) + hh := make([]*Hash, 7) + for i := range hh { + hh[i] = new(Hash) + } + for _, h := range hh[1:] { + h.SetSeed(hh[0].Seed()) + } + hh[0].Write(b) + hh[1].WriteString(string(b)) + + writeByte := func(h *Hash, b byte) { + err := h.WriteByte(b) + if err != nil { + t.Fatalf("WriteByte: %v", err) + } + } + writeSingleByte := func(h *Hash, b byte) { + _, err := h.Write([]byte{b}) + if err != nil { + t.Fatalf("Write single byte: %v", err) + } + } + writeStringSingleByte := func(h *Hash, b byte) { + _, err := h.WriteString(string([]byte{b})) + if err != nil { + t.Fatalf("WriteString single byte: %v", err) + } + } + + for i, x := range b { + writeByte(hh[2], x) + writeSingleByte(hh[3], x) + if i == 0 { + writeByte(hh[4], x) + } else { + writeSingleByte(hh[4], x) + } + writeStringSingleByte(hh[5], x) + if i == 0 { + writeByte(hh[6], x) + } else { + writeStringSingleByte(hh[6], x) + } + } + + sum := hh[0].Sum64() + for i, h := range hh { + if sum != h.Sum64() { + t.Errorf("hash %d not identical to a single Write", i) + } + } + + if sum1 := Bytes(hh[0].Seed(), b); sum1 != hh[0].Sum64() { + t.Errorf("hash using Bytes not identical to a single Write") + } + + if sum1 := String(hh[0].Seed(), string(b)); sum1 != hh[0].Sum64() { + t.Errorf("hash using String not identical to a single Write") + } +} + +func TestHashBytesVsString(t *testing.T) { + s := "foo" + b := []byte(s) + h1 := new(Hash) + h2 := new(Hash) + h2.SetSeed(h1.Seed()) + n1, err1 := h1.WriteString(s) + if n1 != len(s) || err1 != nil { + t.Fatalf("WriteString(s) = %d, %v, want %d, nil", n1, err1, len(s)) + } + n2, err2 := h2.Write(b) + if n2 != len(b) || err2 != nil { + t.Fatalf("Write(b) = %d, %v, want %d, nil", n2, err2, len(b)) + } + if h1.Sum64() != h2.Sum64() { + t.Errorf("hash of string and bytes not identical") + } +} + +func TestHashHighBytes(t *testing.T) { + // See issue 34925. 
+ const N = 10 + m := map[uint64]struct{}{} + for i := 0; i < N; i++ { + h := new(Hash) + h.WriteString("foo") + m[h.Sum64()>>32] = struct{}{} + } + if len(m) < N/2 { + t.Errorf("from %d seeds, wanted at least %d different hashes; got %d", N, N/2, len(m)) + } +} + +func TestRepeat(t *testing.T) { + h1 := new(Hash) + h1.WriteString("testing") + sum1 := h1.Sum64() + + h1.Reset() + h1.WriteString("testing") + sum2 := h1.Sum64() + + if sum1 != sum2 { + t.Errorf("different sum after resetting: %#x != %#x", sum1, sum2) + } + + h2 := new(Hash) + h2.SetSeed(h1.Seed()) + h2.WriteString("testing") + sum3 := h2.Sum64() + + if sum1 != sum3 { + t.Errorf("different sum on the same seed: %#x != %#x", sum1, sum3) + } +} + +func TestSeedFromSum64(t *testing.T) { + h1 := new(Hash) + h1.WriteString("foo") + x := h1.Sum64() // seed generated here + h2 := new(Hash) + h2.SetSeed(h1.Seed()) + h2.WriteString("foo") + y := h2.Sum64() + if x != y { + t.Errorf("hashes don't match: want %x, got %x", x, y) + } +} + +func TestSeedFromSeed(t *testing.T) { + h1 := new(Hash) + h1.WriteString("foo") + _ = h1.Seed() // seed generated here + x := h1.Sum64() + h2 := new(Hash) + h2.SetSeed(h1.Seed()) + h2.WriteString("foo") + y := h2.Sum64() + if x != y { + t.Errorf("hashes don't match: want %x, got %x", x, y) + } +} + +func TestSeedFromFlush(t *testing.T) { + b := make([]byte, 65) + h1 := new(Hash) + h1.Write(b) // seed generated here + x := h1.Sum64() + h2 := new(Hash) + h2.SetSeed(h1.Seed()) + h2.Write(b) + y := h2.Sum64() + if x != y { + t.Errorf("hashes don't match: want %x, got %x", x, y) + } +} + +func TestSeedFromReset(t *testing.T) { + h1 := new(Hash) + h1.WriteString("foo") + h1.Reset() // seed generated here + h1.WriteString("foo") + x := h1.Sum64() + h2 := new(Hash) + h2.SetSeed(h1.Seed()) + h2.WriteString("foo") + y := h2.Sum64() + if x != y { + t.Errorf("hashes don't match: want %x, got %x", x, y) + } +} + +func negativeZero[T float32 | float64]() T { + var f T + f = -f + return f +} + +func TestComparable(t *testing.T) { + testComparable(t, int64(2)) + testComparable(t, uint64(8)) + testComparable(t, uintptr(12)) + testComparable(t, any("s")) + testComparable(t, "s") + testComparable(t, true) + testComparable(t, new(float64)) + testComparable(t, float64(9)) + testComparable(t, complex128(9i+1)) + testComparable(t, struct{}{}) + testComparable(t, struct { + i int + u uint + b bool + f float64 + p *int + a any + }{i: 9, u: 1, b: true, f: 9.9, p: new(int), a: 1}) + type S struct { + s string + } + s1 := S{s: heapStr(t)} + s2 := S{s: heapStr(t)} + if unsafe.StringData(s1.s) == unsafe.StringData(s2.s) { + t.Fatalf("unexpected two heapStr ptr equal") + } + if s1.s != s2.s { + t.Fatalf("unexpected two heapStr value not equal") + } + testComparable(t, s1, s2) + testComparable(t, s1.s, s2.s) + testComparable(t, float32(0), negativeZero[float32]()) + testComparable(t, float64(0), negativeZero[float64]()) + testComparableNoEqual(t, math.NaN(), math.NaN()) + testComparableNoEqual(t, [2]string{"a", ""}, [2]string{"", "a"}) + testComparableNoEqual(t, struct{ a, b string }{"foo", ""}, struct{ a, b string }{"", "foo"}) + testComparableNoEqual(t, struct{ a, b any }{int(0), struct{}{}}, struct{ a, b any }{struct{}{}, int(0)}) +} + +func testComparableNoEqual[T comparable](t *testing.T, v1, v2 T) { + seed := MakeSeed() + if Comparable(seed, v1) == Comparable(seed, v2) { + t.Fatalf("Comparable(seed, %v) == Comparable(seed, %v)", v1, v2) + } +} + +var heapStrValue = []byte("aTestString") + +func heapStr(t *testing.T) string { + 
return string(heapStrValue) +} + +func testComparable[T comparable](t *testing.T, v T, v2 ...T) { + t.Run(TypeFor[T]().String(), func(t *testing.T) { + var a, b T = v, v + if len(v2) != 0 { + b = v2[0] + } + var pa *T = &a + seed := MakeSeed() + if Comparable(seed, a) != Comparable(seed, b) { + t.Fatalf("Comparable(seed, %v) != Comparable(seed, %v)", a, b) + } + old := Comparable(seed, pa) + stackGrow(8192) + new := Comparable(seed, pa) + if old != new { + t.Fatal("Comparable(seed, ptr) != Comparable(seed, ptr)") + } + }) +} + +var use byte + +//go:noinline +func stackGrow(dep int) { + if dep == 0 { + return + } + var local [1024]byte + // make sure local is allocated on the stack. + local[rand.Uint64()%1024] = byte(rand.Uint64()) + use = local[rand.Uint64()%1024] + stackGrow(dep - 1) +} + +func TestWriteComparable(t *testing.T) { + testWriteComparable(t, int64(2)) + testWriteComparable(t, uint64(8)) + testWriteComparable(t, uintptr(12)) + testWriteComparable(t, any("s")) + testWriteComparable(t, "s") + testComparable(t, true) + testWriteComparable(t, new(float64)) + testWriteComparable(t, float64(9)) + testWriteComparable(t, complex128(9i+1)) + testWriteComparable(t, struct{}{}) + testWriteComparable(t, struct { + i int + u uint + b bool + f float64 + p *int + a any + }{i: 9, u: 1, b: true, f: 9.9, p: new(int), a: 1}) + type S struct { + s string + } + s1 := S{s: heapStr(t)} + s2 := S{s: heapStr(t)} + if unsafe.StringData(s1.s) == unsafe.StringData(s2.s) { + t.Fatalf("unexpected two heapStr ptr equal") + } + if s1.s != s2.s { + t.Fatalf("unexpected two heapStr value not equal") + } + testWriteComparable(t, s1, s2) + testWriteComparable(t, s1.s, s2.s) + testWriteComparable(t, float32(0), negativeZero[float32]()) + testWriteComparable(t, float64(0), negativeZero[float64]()) + testWriteComparableNoEqual(t, math.NaN(), math.NaN()) + testWriteComparableNoEqual(t, [2]string{"a", ""}, [2]string{"", "a"}) + testWriteComparableNoEqual(t, struct{ a, b string }{"foo", ""}, struct{ a, b string }{"", "foo"}) + testWriteComparableNoEqual(t, struct{ a, b any }{int(0), struct{}{}}, struct{ a, b any }{struct{}{}, int(0)}) +} + +func testWriteComparableNoEqual[T comparable](t *testing.T, v1, v2 T) { + seed := MakeSeed() + h1 := Hash{} + h2 := Hash{} + *(*Seed)(unsafe.Pointer(&h1)), *(*Seed)(unsafe.Pointer(&h2)) = seed, seed + WriteComparable(&h1, v1) + WriteComparable(&h2, v2) + if h1.Sum64() == h2.Sum64() { + t.Fatalf("WriteComparable(seed, %v) == WriteComparable(seed, %v)", v1, v2) + } + +} + +func testWriteComparable[T comparable](t *testing.T, v T, v2 ...T) { + t.Run(TypeFor[T]().String(), func(t *testing.T) { + var a, b T = v, v + if len(v2) != 0 { + b = v2[0] + } + var pa *T = &a + h1 := Hash{} + h2 := Hash{} + *(*Seed)(unsafe.Pointer(&h1)) = MakeSeed() + h2 = h1 + WriteComparable(&h1, a) + WriteComparable(&h2, b) + if h1.Sum64() != h2.Sum64() { + t.Fatalf("WriteComparable(h, %v) != WriteComparable(h, %v)", a, b) + } + WriteComparable(&h1, pa) + old := h1.Sum64() + stackGrow(8192) + WriteComparable(&h2, pa) + new := h2.Sum64() + if old != new { + t.Fatal("WriteComparable(seed, ptr) != WriteComparable(seed, ptr)") + } + }) +} + +func TestComparableShouldPanic(t *testing.T) { + s := []byte("s") + a := any(s) + defer func() { + e := recover() + err, ok := e.(error) + if !ok { + t.Fatalf("Comaparable(any([]byte)) should panic") + } + want := "hash of unhashable type []uint8" + if s := err.Error(); !strings.Contains(s, want) { + t.Fatalf("want %s, got %s", want, s) + } + }() + Comparable(MakeSeed(), a) +} 
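// A minimal usage sketch (illustrative; the struct fields and values are
// made up): Comparable hashes any comparable value under a seed, which is
// what common/xsync/map.go relies on to hash arbitrary struct keys without
// a user-supplied hash function.
func ExampleComparable() {
	type hostPort struct {
		host string
		port int
	}
	seed := MakeSeed()
	a := Comparable(seed, hostPort{"example.com", 443})
	b := Comparable(seed, hostPort{"example.com", 443})
	fmt.Println(a == b) // equal values hash equally under the same seed
	// Output: true
}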
+ +func TestWriteComparableNoncommute(t *testing.T) { + seed := MakeSeed() + var h1, h2 Hash + h1.SetSeed(seed) + h2.SetSeed(seed) + + h1.WriteString("abc") + WriteComparable(&h1, 123) + WriteComparable(&h2, 123) + h2.WriteString("abc") + + if h1.Sum64() == h2.Sum64() { + t.Errorf("WriteComparable and WriteString unexpectedly commute") + } +} + +func TestComparableAllocations(t *testing.T) { + t.Skip("test broken in old golang version") + seed := MakeSeed() + x := heapStr(t) + allocs := testing.AllocsPerRun(10, func() { + s := "s" + x + Comparable(seed, s) + }) + if allocs > 0 { + t.Errorf("got %v allocs, want 0", allocs) + } + + type S struct { + a int + b string + } + allocs = testing.AllocsPerRun(10, func() { + s := S{123, "s" + x} + Comparable(seed, s) + }) + if allocs > 0 { + t.Errorf("got %v allocs, want 0", allocs) + } +} + +// Make sure a Hash implements the hash.Hash and hash.Hash64 interfaces. +var _ hash.Hash = &Hash{} +var _ hash.Hash64 = &Hash{} + +func benchmarkSize(b *testing.B, size int) { + h := &Hash{} + buf := make([]byte, size) + s := string(buf) + + b.Run("Write", func(b *testing.B) { + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + h.Reset() + h.Write(buf) + h.Sum64() + } + }) + + b.Run("Bytes", func(b *testing.B) { + b.SetBytes(int64(size)) + seed := h.Seed() + for i := 0; i < b.N; i++ { + Bytes(seed, buf) + } + }) + + b.Run("String", func(b *testing.B) { + b.SetBytes(int64(size)) + seed := h.Seed() + for i := 0; i < b.N; i++ { + String(seed, s) + } + }) +} + +func BenchmarkHash(b *testing.B) { + sizes := []int{4, 8, 16, 32, 64, 256, 320, 1024, 4096, 16384} + for _, size := range sizes { + b.Run(fmt.Sprint("n=", size), func(b *testing.B) { + benchmarkSize(b, size) + }) + } +} + +func benchmarkComparable[T comparable](b *testing.B, v T) { + b.Run(TypeFor[T]().String(), func(b *testing.B) { + seed := MakeSeed() + for i := 0; i < b.N; i++ { + Comparable(seed, v) + } + }) +} + +func BenchmarkComparable(b *testing.B) { + type testStruct struct { + i int + u uint + b bool + f float64 + p *int + a any + } + benchmarkComparable(b, int64(2)) + benchmarkComparable(b, uint64(8)) + benchmarkComparable(b, uintptr(12)) + benchmarkComparable(b, any("s")) + benchmarkComparable(b, "s") + benchmarkComparable(b, true) + benchmarkComparable(b, new(float64)) + benchmarkComparable(b, float64(9)) + benchmarkComparable(b, complex128(9i+1)) + benchmarkComparable(b, struct{}{}) + benchmarkComparable(b, testStruct{i: 9, u: 1, b: true, f: 9.9, p: new(int), a: 1}) +} + +// TypeFor returns the [Type] that represents the type argument T. +func TypeFor[T any]() reflect.Type { + var v T + if t := reflect.TypeOf(v); t != nil { + return t // optimize for T being a non-interface kind + } + return reflect.TypeOf((*T)(nil)).Elem() // only for an interface kind +} diff --git a/common/xsync/map.go b/common/xsync/map.go new file mode 100644 index 00000000..198241b9 --- /dev/null +++ b/common/xsync/map.go @@ -0,0 +1,915 @@ +package xsync + +// copy and modified from https://github.com/puzpuzpuz/xsync/blob/v4.1.0/map.go +// which is licensed under Apache v2. 
+// +// parallel Map resize has been removed to decrease the memory using + +import ( + "fmt" + "math" + "math/bits" + "strings" + "sync" + "sync/atomic" + "unsafe" + + "github.com/metacubex/mihomo/common/maphash" +) + +const ( + // number of Map entries per bucket; 5 entries lead to size of 64B + // (one cache line) on 64-bit machines + entriesPerMapBucket = 5 + // threshold fraction of table occupation to start a table shrinking + // when deleting the last entry in a bucket chain + mapShrinkFraction = 128 + // map load factor to trigger a table resize during insertion; + // a map holds up to mapLoadFactor*entriesPerMapBucket*mapTableLen + // key-value pairs (this is a soft limit) + mapLoadFactor = 0.75 + // minimal table size, i.e. number of buckets; thus, minimal map + // capacity can be calculated as entriesPerMapBucket*defaultMinMapTableLen + defaultMinMapTableLen = 32 + // minimum counter stripes to use + minMapCounterLen = 8 + // maximum counter stripes to use; stands for around 4KB of memory + maxMapCounterLen = 32 + defaultMeta uint64 = 0x8080808080808080 + metaMask uint64 = 0xffffffffff + defaultMetaMasked uint64 = defaultMeta & metaMask + emptyMetaSlot uint8 = 0x80 +) + +type mapResizeHint int + +const ( + mapGrowHint mapResizeHint = 0 + mapShrinkHint mapResizeHint = 1 + mapClearHint mapResizeHint = 2 +) + +type ComputeOp int + +const ( + // CancelOp signals to Compute to not do anything as a result + // of executing the lambda. If the entry was not present in + // the map, nothing happens, and if it was present, the + // returned value is ignored. + CancelOp ComputeOp = iota + // UpdateOp signals to Compute to update the entry to the + // value returned by the lambda, creating it if necessary. + UpdateOp + // DeleteOp signals to Compute to always delete the entry + // from the map. + DeleteOp +) + +type loadOp int + +const ( + noLoadOp loadOp = iota + loadOrComputeOp + loadAndDeleteOp +) + +// Map is like a Go map[K]V but is safe for concurrent +// use by multiple goroutines without additional locking or +// coordination. It follows the interface of sync.Map with +// a number of valuable extensions like Compute or Size. +// +// A Map must not be copied after first use. +// +// Map uses a modified version of Cache-Line Hash Table (CLHT) +// data structure: https://github.com/LPD-EPFL/CLHT +// +// CLHT is built around idea to organize the hash table in +// cache-line-sized buckets, so that on all modern CPUs update +// operations complete with at most one cache-line transfer. +// Also, Get operations involve no write to memory, as well as no +// mutexes or any other sort of locks. Due to this design, in all +// considered scenarios Map outperforms sync.Map. +// +// Map also borrows ideas from Java's j.u.c.ConcurrentHashMap +// (immutable K/V pair structs instead of atomic snapshots) +// and C++'s absl::flat_hash_map (meta memory and SWAR-based +// lookups). 
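// A minimal sketch of the SWAR lookup described above, using the helpers
// defined later in this file (the slot index 2 and the value 0x15 are
// arbitrary examples): each bucket keeps one meta byte per entry slot,
// 0x80 for an empty slot or h2 of the stored key otherwise, so a single
// XOR plus markZeroBytes locates candidate slots without loading entries.
func exampleMetaProbe() int {
	meta := setByte(defaultMeta, 0x15, 2) // pretend slot 2 holds a key with h2 == 0x15
	h2w := broadcast(0x15)                // replicate the probe byte into all lanes
	marked := markZeroBytes(meta^h2w) & metaMask
	if marked == 0 {
		return -1 // no candidate slot in this bucket
	}
	return firstMarkedByteIndex(marked) // 2; callers still re-check the actual key
}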
+type Map[K comparable, V any] struct { + totalGrowths atomic.Int64 + totalShrinks atomic.Int64 + resizing atomic.Bool // resize in progress flag + resizeMu sync.Mutex // only used along with resizeCond + resizeCond sync.Cond // used to wake up resize waiters (concurrent modifications) + table atomic.Pointer[mapTable[K, V]] + minTableLen int + growOnly bool +} + +type mapTable[K comparable, V any] struct { + buckets []bucketPadded[K, V] + // striped counter for number of table entries; + // used to determine if a table shrinking is needed + // occupies min(buckets_memory/1024, 64KB) of memory + size []counterStripe + seed maphash.Seed +} + +type counterStripe struct { + c int64 + // Padding to prevent false sharing. + _ [cacheLineSize - 8]byte +} + +// bucketPadded is a CL-sized map bucket holding up to +// entriesPerMapBucket entries. +type bucketPadded[K comparable, V any] struct { + //lint:ignore U1000 ensure each bucket takes two cache lines on both 32 and 64-bit archs + pad [cacheLineSize - unsafe.Sizeof(bucket[K, V]{})]byte + bucket[K, V] +} + +type bucket[K comparable, V any] struct { + meta atomic.Uint64 + entries [entriesPerMapBucket]atomic.Pointer[entry[K, V]] // *entry + next atomic.Pointer[bucketPadded[K, V]] // *bucketPadded + mu sync.Mutex +} + +// entry is an immutable map entry. +type entry[K comparable, V any] struct { + key K + value V +} + +// MapConfig defines configurable Map options. +type MapConfig struct { + sizeHint int + growOnly bool +} + +// WithPresize configures new Map instance with capacity enough +// to hold sizeHint entries. The capacity is treated as the minimal +// capacity meaning that the underlying hash table will never shrink +// to a smaller capacity. If sizeHint is zero or negative, the value +// is ignored. +func WithPresize(sizeHint int) func(*MapConfig) { + return func(c *MapConfig) { + c.sizeHint = sizeHint + } +} + +// WithGrowOnly configures new Map instance to be grow-only. +// This means that the underlying hash table grows in capacity when +// new keys are added, but does not shrink when keys are deleted. +// The only exception to this rule is the Clear method which +// shrinks the hash table back to the initial capacity. +func WithGrowOnly() func(*MapConfig) { + return func(c *MapConfig) { + c.growOnly = true + } +} + +// NewMap creates a new Map instance configured with the given +// options. 
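// A small sketch of the options above (the size hint is illustrative):
// presize the table for the expected number of entries and keep it
// grow-only, then use the map much like sync.Map.
func exampleNewMapOptions() {
	m := NewMap[string, int](WithPresize(10_000), WithGrowOnly())
	m.Store("foo", 1)
	if v, ok := m.Load("foo"); ok {
		_ = v // v == 1
	}
}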
+func NewMap[K comparable, V any](options ...func(*MapConfig)) *Map[K, V] { + c := &MapConfig{ + sizeHint: defaultMinMapTableLen * entriesPerMapBucket, + } + for _, o := range options { + o(c) + } + + m := &Map[K, V]{} + m.resizeCond = *sync.NewCond(&m.resizeMu) + var table *mapTable[K, V] + if c.sizeHint <= defaultMinMapTableLen*entriesPerMapBucket { + table = newMapTable[K, V](defaultMinMapTableLen) + } else { + tableLen := nextPowOf2(uint32((float64(c.sizeHint) / entriesPerMapBucket) / mapLoadFactor)) + table = newMapTable[K, V](int(tableLen)) + } + m.minTableLen = len(table.buckets) + m.growOnly = c.growOnly + m.table.Store(table) + return m +} + +func newMapTable[K comparable, V any](minTableLen int) *mapTable[K, V] { + buckets := make([]bucketPadded[K, V], minTableLen) + for i := range buckets { + buckets[i].meta.Store(defaultMeta) + } + counterLen := minTableLen >> 10 + if counterLen < minMapCounterLen { + counterLen = minMapCounterLen + } else if counterLen > maxMapCounterLen { + counterLen = maxMapCounterLen + } + counter := make([]counterStripe, counterLen) + t := &mapTable[K, V]{ + buckets: buckets, + size: counter, + seed: maphash.MakeSeed(), + } + return t +} + +// ToPlainMap returns a native map with a copy of xsync Map's +// contents. The copied xsync Map should not be modified while +// this call is made. If the copied Map is modified, the copying +// behavior is the same as in the Range method. +func ToPlainMap[K comparable, V any](m *Map[K, V]) map[K]V { + pm := make(map[K]V) + if m != nil { + m.Range(func(key K, value V) bool { + pm[key] = value + return true + }) + } + return pm +} + +// Load returns the value stored in the map for a key, or zero value +// of type V if no value is present. +// The ok result indicates whether value was found in the map. +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + table := m.table.Load() + hash := maphash.Comparable(table.seed, key) + h1 := h1(hash) + h2w := broadcast(h2(hash)) + bidx := uint64(len(table.buckets)-1) & h1 + b := &table.buckets[bidx] + for { + metaw := b.meta.Load() + markedw := markZeroBytes(metaw^h2w) & metaMask + for markedw != 0 { + idx := firstMarkedByteIndex(markedw) + e := b.entries[idx].Load() + if e != nil { + if e.key == key { + return e.value, true + } + } + markedw &= markedw - 1 + } + b = b.next.Load() + if b == nil { + return + } + } +} + +// Store sets the value for a key. +func (m *Map[K, V]) Store(key K, value V) { + m.doCompute( + key, + func(V, bool) (V, ComputeOp) { + return value, UpdateOp + }, + noLoadOp, + false, + ) +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + return m.doCompute( + key, + func(oldValue V, loaded bool) (V, ComputeOp) { + if loaded { + return oldValue, CancelOp + } + return value, UpdateOp + }, + loadOrComputeOp, + false, + ) +} + +// LoadAndStore returns the existing value for the key if present, +// while setting the new value for the key. +// It stores the new value and returns the existing one, if present. +// The loaded result is true if the existing value was loaded, +// false otherwise. 
+func (m *Map[K, V]) LoadAndStore(key K, value V) (actual V, loaded bool) { + return m.doCompute( + key, + func(V, bool) (V, ComputeOp) { + return value, UpdateOp + }, + noLoadOp, + false, + ) +} + +// LoadOrCompute returns the existing value for the key if +// present. Otherwise, it tries to compute the value using the +// provided function and, if successful, stores and returns +// the computed value. The loaded result is true if the value was +// loaded, or false if computed. If valueFn returns true as the +// cancel value, the computation is cancelled and the zero value +// for type V is returned. +// +// This call locks a hash table bucket while the compute function +// is executed. It means that modifications on other entries in +// the bucket will be blocked until the valueFn executes. Consider +// this when the function includes long-running operations. +func (m *Map[K, V]) LoadOrCompute( + key K, + valueFn func() (newValue V, cancel bool), +) (value V, loaded bool) { + return m.doCompute( + key, + func(oldValue V, loaded bool) (V, ComputeOp) { + if loaded { + return oldValue, CancelOp + } + newValue, c := valueFn() + if !c { + return newValue, UpdateOp + } + return oldValue, CancelOp + }, + loadOrComputeOp, + false, + ) +} + +// Compute either sets the computed new value for the key, +// deletes the value for the key, or does nothing, based on +// the returned [ComputeOp]. When the op returned by valueFn +// is [UpdateOp], the value is updated to the new value. If +// it is [DeleteOp], the entry is removed from the map +// altogether. And finally, if the op is [CancelOp] then the +// entry is left as-is. In other words, if it did not already +// exist, it is not created, and if it did exist, it is not +// updated. This is useful to synchronously execute some +// operation on the value without incurring the cost of +// updating the map every time. The ok result indicates +// whether the entry is present in the map after the compute +// operation. The actual result contains the value of the map +// if a corresponding entry is present, or the zero value +// otherwise. See the example for a few use cases. +// +// This call locks a hash table bucket while the compute function +// is executed. It means that modifications on other entries in +// the bucket will be blocked until the valueFn executes. Consider +// this when the function includes long-running operations. +func (m *Map[K, V]) Compute( + key K, + valueFn func(oldValue V, loaded bool) (newValue V, op ComputeOp), +) (actual V, ok bool) { + return m.doCompute(key, valueFn, noLoadOp, true) +} + +// LoadAndDelete deletes the value for a key, returning the previous +// value if any. The loaded result reports whether the key was +// present. +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + return m.doCompute( + key, + func(value V, loaded bool) (V, ComputeOp) { + return value, DeleteOp + }, + loadAndDeleteOp, + false, + ) +} + +// Delete deletes the value for a key. 
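// A short sketch of Compute covering the three ComputeOp results referred
// to above (the key name and the limit of 100 are illustrative): update an
// entry in place, delete it, or cancel and leave it untouched.
func exampleCompute(m *Map[string, int]) {
	// Increment a counter, creating it on first use.
	m.Compute("hits", func(old int, loaded bool) (int, ComputeOp) {
		return old + 1, UpdateOp
	})
	// Drop the counter once it exceeds a limit, otherwise leave it as-is.
	m.Compute("hits", func(old int, loaded bool) (int, ComputeOp) {
		if loaded && old > 100 {
			return 0, DeleteOp
		}
		return old, CancelOp
	})
}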
+func (m *Map[K, V]) Delete(key K) { + m.LoadAndDelete(key) +} + +func (m *Map[K, V]) doCompute( + key K, + valueFn func(oldValue V, loaded bool) (V, ComputeOp), + loadOp loadOp, + computeOnly bool, +) (V, bool) { + for { + compute_attempt: + var ( + emptyb *bucketPadded[K, V] + emptyidx int + ) + table := m.table.Load() + tableLen := len(table.buckets) + hash := maphash.Comparable(table.seed, key) + h1 := h1(hash) + h2 := h2(hash) + h2w := broadcast(h2) + bidx := uint64(len(table.buckets)-1) & h1 + rootb := &table.buckets[bidx] + + if loadOp != noLoadOp { + b := rootb + load: + for { + metaw := b.meta.Load() + markedw := markZeroBytes(metaw^h2w) & metaMask + for markedw != 0 { + idx := firstMarkedByteIndex(markedw) + e := b.entries[idx].Load() + if e != nil { + if e.key == key { + if loadOp == loadOrComputeOp { + return e.value, true + } + break load + } + } + markedw &= markedw - 1 + } + b = b.next.Load() + if b == nil { + if loadOp == loadAndDeleteOp { + return *new(V), false + } + break load + } + } + } + + rootb.mu.Lock() + // The following two checks must go in reverse to what's + // in the resize method. + if m.resizeInProgress() { + // Resize is in progress. Wait, then go for another attempt. + rootb.mu.Unlock() + m.waitForResize() + goto compute_attempt + } + if m.newerTableExists(table) { + // Someone resized the table. Go for another attempt. + rootb.mu.Unlock() + goto compute_attempt + } + b := rootb + for { + metaw := b.meta.Load() + markedw := markZeroBytes(metaw^h2w) & metaMask + for markedw != 0 { + idx := firstMarkedByteIndex(markedw) + e := b.entries[idx].Load() + if e != nil { + if e.key == key { + // In-place update/delete. + // We get a copy of the value via an interface{} on each call, + // thus the live value pointers are unique. Otherwise atomic + // snapshot won't be correct in case of multiple Store calls + // using the same value. + oldv := e.value + newv, op := valueFn(oldv, true) + switch op { + case DeleteOp: + // Deletion. + // First we update the hash, then the entry. + newmetaw := setByte(metaw, emptyMetaSlot, idx) + b.meta.Store(newmetaw) + b.entries[idx].Store(nil) + rootb.mu.Unlock() + table.addSize(bidx, -1) + // Might need to shrink the table if we left bucket empty. + if newmetaw == defaultMeta { + m.resize(table, mapShrinkHint) + } + return oldv, !computeOnly + case UpdateOp: + newe := new(entry[K, V]) + newe.key = key + newe.value = newv + b.entries[idx].Store(newe) + case CancelOp: + newv = oldv + } + rootb.mu.Unlock() + if computeOnly { + // Compute expects the new value to be returned. + return newv, true + } + // LoadAndStore expects the old value to be returned. + return oldv, true + } + } + markedw &= markedw - 1 + } + if emptyb == nil { + // Search for empty entries (up to 5 per bucket). + emptyw := metaw & defaultMetaMasked + if emptyw != 0 { + idx := firstMarkedByteIndex(emptyw) + emptyb = b + emptyidx = idx + } + } + if b.next.Load() == nil { + if emptyb != nil { + // Insertion into an existing bucket. + var zeroV V + newValue, op := valueFn(zeroV, false) + switch op { + case DeleteOp, CancelOp: + rootb.mu.Unlock() + return zeroV, false + default: + newe := new(entry[K, V]) + newe.key = key + newe.value = newValue + // First we update meta, then the entry. 
+ emptyb.meta.Store(setByte(emptyb.meta.Load(), h2, emptyidx)) + emptyb.entries[emptyidx].Store(newe) + rootb.mu.Unlock() + table.addSize(bidx, 1) + return newValue, computeOnly + } + } + growThreshold := float64(tableLen) * entriesPerMapBucket * mapLoadFactor + if table.sumSize() > int64(growThreshold) { + // Need to grow the table. Then go for another attempt. + rootb.mu.Unlock() + m.resize(table, mapGrowHint) + goto compute_attempt + } + // Insertion into a new bucket. + var zeroV V + newValue, op := valueFn(zeroV, false) + switch op { + case DeleteOp, CancelOp: + rootb.mu.Unlock() + return newValue, false + default: + // Create and append a bucket. + newb := new(bucketPadded[K, V]) + newb.meta.Store(setByte(defaultMeta, h2, 0)) + newe := new(entry[K, V]) + newe.key = key + newe.value = newValue + newb.entries[0].Store(newe) + b.next.Store(newb) + rootb.mu.Unlock() + table.addSize(bidx, 1) + return newValue, computeOnly + } + } + b = b.next.Load() + } + } +} + +func (m *Map[K, V]) newerTableExists(table *mapTable[K, V]) bool { + return table != m.table.Load() +} + +func (m *Map[K, V]) resizeInProgress() bool { + return m.resizing.Load() +} + +func (m *Map[K, V]) waitForResize() { + m.resizeMu.Lock() + for m.resizeInProgress() { + m.resizeCond.Wait() + } + m.resizeMu.Unlock() +} + +func (m *Map[K, V]) resize(knownTable *mapTable[K, V], hint mapResizeHint) { + knownTableLen := len(knownTable.buckets) + // Fast path for shrink attempts. + if hint == mapShrinkHint { + if m.growOnly || + m.minTableLen == knownTableLen || + knownTable.sumSize() > int64((knownTableLen*entriesPerMapBucket)/mapShrinkFraction) { + return + } + } + // Slow path. + if !m.resizing.CompareAndSwap(false, true) { + // Someone else started resize. Wait for it to finish. + m.waitForResize() + return + } + var newTable *mapTable[K, V] + table := m.table.Load() + tableLen := len(table.buckets) + switch hint { + case mapGrowHint: + // Grow the table with factor of 2. + m.totalGrowths.Add(1) + newTable = newMapTable[K, V](tableLen << 1) + case mapShrinkHint: + shrinkThreshold := int64((tableLen * entriesPerMapBucket) / mapShrinkFraction) + if tableLen > m.minTableLen && table.sumSize() <= shrinkThreshold { + // Shrink the table with factor of 2. + m.totalShrinks.Add(1) + newTable = newMapTable[K, V](tableLen >> 1) + } else { + // No need to shrink. Wake up all waiters and give up. + m.resizeMu.Lock() + m.resizing.Store(false) + m.resizeCond.Broadcast() + m.resizeMu.Unlock() + return + } + case mapClearHint: + newTable = newMapTable[K, V](m.minTableLen) + default: + panic(fmt.Sprintf("unexpected resize hint: %d", hint)) + } + // Copy the data only if we're not clearing the map. + if hint != mapClearHint { + for i := 0; i < tableLen; i++ { + copied := copyBucket(&table.buckets[i], newTable) + newTable.addSizePlain(uint64(i), copied) + } + } + // Publish the new table and wake up all waiters. 
+ m.table.Store(newTable) + m.resizeMu.Lock() + m.resizing.Store(false) + m.resizeCond.Broadcast() + m.resizeMu.Unlock() +} + +func copyBucket[K comparable, V any]( + b *bucketPadded[K, V], + destTable *mapTable[K, V], +) (copied int) { + rootb := b + rootb.mu.Lock() + for { + for i := 0; i < entriesPerMapBucket; i++ { + if e := b.entries[i].Load(); e != nil { + hash := maphash.Comparable(destTable.seed, e.key) + bidx := uint64(len(destTable.buckets)-1) & h1(hash) + destb := &destTable.buckets[bidx] + appendToBucket(h2(hash), b.entries[i].Load(), destb) + copied++ + } + } + if next := b.next.Load(); next == nil { + rootb.mu.Unlock() + return + } else { + b = next + } + } +} + +// Range calls f sequentially for each key and value present in the +// map. If f returns false, range stops the iteration. +// +// Range does not necessarily correspond to any consistent snapshot +// of the Map's contents: no key will be visited more than once, but +// if the value for any key is stored or deleted concurrently, Range +// may reflect any mapping for that key from any point during the +// Range call. +// +// It is safe to modify the map while iterating it, including entry +// creation, modification and deletion. However, the concurrent +// modification rule apply, i.e. the changes may be not reflected +// in the subsequently iterated entries. +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + // Pre-allocate array big enough to fit entries for most hash tables. + bentries := make([]*entry[K, V], 0, 16*entriesPerMapBucket) + table := m.table.Load() + for i := range table.buckets { + rootb := &table.buckets[i] + b := rootb + // Prevent concurrent modifications and copy all entries into + // the intermediate slice. + rootb.mu.Lock() + for { + for i := 0; i < entriesPerMapBucket; i++ { + if entry := b.entries[i].Load(); entry != nil { + bentries = append(bentries, entry) + } + } + if next := b.next.Load(); next == nil { + rootb.mu.Unlock() + break + } else { + b = next + } + } + // Call the function for all copied entries. + for j, e := range bentries { + if !f(e.key, e.value) { + return + } + // Remove the reference to avoid preventing the copied + // entries from being GCed until this method finishes. + bentries[j] = nil + } + bentries = bentries[:0] + } +} + +// Clear deletes all keys and values currently stored in the map. +func (m *Map[K, V]) Clear() { + m.resize(m.table.Load(), mapClearHint) +} + +// Size returns current size of the map. 
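// A brief sketch of the Range semantics described above (stale is an
// assumed caller-provided predicate): entries may be deleted from inside
// the callback, and the iteration is not a consistent snapshot of the map.
func exampleRangeCleanup(m *Map[string, int], stale func(int) bool) {
	m.Range(func(k string, v int) bool {
		if stale(v) {
			m.Delete(k)
		}
		return true // keep iterating
	})
}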
+func (m *Map[K, V]) Size() int { + return int(m.table.Load().sumSize()) +} + +func appendToBucket[K comparable, V any](h2 uint8, e *entry[K, V], b *bucketPadded[K, V]) { + for { + for i := 0; i < entriesPerMapBucket; i++ { + if b.entries[i].Load() == nil { + b.meta.Store(setByte(b.meta.Load(), h2, i)) + b.entries[i].Store(e) + return + } + } + if next := b.next.Load(); next == nil { + newb := new(bucketPadded[K, V]) + newb.meta.Store(setByte(defaultMeta, h2, 0)) + newb.entries[0].Store(e) + b.next.Store(newb) + return + } else { + b = next + } + } +} + +func (table *mapTable[K, V]) addSize(bucketIdx uint64, delta int) { + cidx := uint64(len(table.size)-1) & bucketIdx + atomic.AddInt64(&table.size[cidx].c, int64(delta)) +} + +func (table *mapTable[K, V]) addSizePlain(bucketIdx uint64, delta int) { + cidx := uint64(len(table.size)-1) & bucketIdx + table.size[cidx].c += int64(delta) +} + +func (table *mapTable[K, V]) sumSize() int64 { + sum := int64(0) + for i := range table.size { + sum += atomic.LoadInt64(&table.size[i].c) + } + return sum +} + +func h1(h uint64) uint64 { + return h >> 7 +} + +func h2(h uint64) uint8 { + return uint8(h & 0x7f) +} + +// MapStats is Map statistics. +// +// Warning: map statistics are intented to be used for diagnostic +// purposes, not for production code. This means that breaking changes +// may be introduced into this struct even between minor releases. +type MapStats struct { + // RootBuckets is the number of root buckets in the hash table. + // Each bucket holds a few entries. + RootBuckets int + // TotalBuckets is the total number of buckets in the hash table, + // including root and their chained buckets. Each bucket holds + // a few entries. + TotalBuckets int + // EmptyBuckets is the number of buckets that hold no entries. + EmptyBuckets int + // Capacity is the Map capacity, i.e. the total number of + // entries that all buckets can physically hold. This number + // does not consider the load factor. + Capacity int + // Size is the exact number of entries stored in the map. + Size int + // Counter is the number of entries stored in the map according + // to the internal atomic counter. In case of concurrent map + // modifications this number may be different from Size. + Counter int + // CounterLen is the number of internal atomic counter stripes. + // This number may grow with the map capacity to improve + // multithreaded scalability. + CounterLen int + // MinEntries is the minimum number of entries per a chain of + // buckets, i.e. a root bucket and its chained buckets. + MinEntries int + // MinEntries is the maximum number of entries per a chain of + // buckets, i.e. a root bucket and its chained buckets. + MaxEntries int + // TotalGrowths is the number of times the hash table grew. + TotalGrowths int64 + // TotalGrowths is the number of times the hash table shrinked. + TotalShrinks int64 +} + +// ToString returns string representation of map stats. 
+func (s *MapStats) ToString() string { + var sb strings.Builder + sb.WriteString("MapStats{\n") + sb.WriteString(fmt.Sprintf("RootBuckets: %d\n", s.RootBuckets)) + sb.WriteString(fmt.Sprintf("TotalBuckets: %d\n", s.TotalBuckets)) + sb.WriteString(fmt.Sprintf("EmptyBuckets: %d\n", s.EmptyBuckets)) + sb.WriteString(fmt.Sprintf("Capacity: %d\n", s.Capacity)) + sb.WriteString(fmt.Sprintf("Size: %d\n", s.Size)) + sb.WriteString(fmt.Sprintf("Counter: %d\n", s.Counter)) + sb.WriteString(fmt.Sprintf("CounterLen: %d\n", s.CounterLen)) + sb.WriteString(fmt.Sprintf("MinEntries: %d\n", s.MinEntries)) + sb.WriteString(fmt.Sprintf("MaxEntries: %d\n", s.MaxEntries)) + sb.WriteString(fmt.Sprintf("TotalGrowths: %d\n", s.TotalGrowths)) + sb.WriteString(fmt.Sprintf("TotalShrinks: %d\n", s.TotalShrinks)) + sb.WriteString("}\n") + return sb.String() +} + +// Stats returns statistics for the Map. Just like other map +// methods, this one is thread-safe. Yet it's an O(N) operation, +// so it should be used only for diagnostics or debugging purposes. +func (m *Map[K, V]) Stats() MapStats { + stats := MapStats{ + TotalGrowths: m.totalGrowths.Load(), + TotalShrinks: m.totalShrinks.Load(), + MinEntries: math.MaxInt32, + } + table := m.table.Load() + stats.RootBuckets = len(table.buckets) + stats.Counter = int(table.sumSize()) + stats.CounterLen = len(table.size) + for i := range table.buckets { + nentries := 0 + b := &table.buckets[i] + stats.TotalBuckets++ + for { + nentriesLocal := 0 + stats.Capacity += entriesPerMapBucket + for i := 0; i < entriesPerMapBucket; i++ { + if b.entries[i].Load() != nil { + stats.Size++ + nentriesLocal++ + } + } + nentries += nentriesLocal + if nentriesLocal == 0 { + stats.EmptyBuckets++ + } + if next := b.next.Load(); next == nil { + break + } else { + b = next + } + stats.TotalBuckets++ + } + if nentries < stats.MinEntries { + stats.MinEntries = nentries + } + if nentries > stats.MaxEntries { + stats.MaxEntries = nentries + } + } + return stats +} + +const ( + // cacheLineSize is used in paddings to prevent false sharing; + // 64B are used instead of 128B as a compromise between + // memory footprint and performance; 128B usage may give ~30% + // improvement on NUMA machines. + cacheLineSize = 64 +) + +// nextPowOf2 computes the next highest power of 2 of 32-bit v. +// Source: https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 +func nextPowOf2(v uint32) uint32 { + if v == 0 { + return 1 + } + v-- + v |= v >> 1 + v |= v >> 2 + v |= v >> 4 + v |= v >> 8 + v |= v >> 16 + v++ + return v +} + +func broadcast(b uint8) uint64 { + return 0x101010101010101 * uint64(b) +} + +func firstMarkedByteIndex(w uint64) int { + return bits.TrailingZeros64(w) >> 3 +} + +// SWAR byte search: may produce false positives, e.g. for 0x0100, +// so make sure to double-check bytes found by this function. +func markZeroBytes(w uint64) uint64 { + return ((w - 0x0101010101010101) & (^w) & 0x8080808080808080) +} + +func setByte(w uint64, b uint8, idx int) uint64 { + shift := idx << 3 + return (w &^ (0xff << shift)) | (uint64(b) << shift) +} diff --git a/common/xsync/map_extra.go b/common/xsync/map_extra.go new file mode 100644 index 00000000..8b62af60 --- /dev/null +++ b/common/xsync/map_extra.go @@ -0,0 +1,28 @@ +package xsync + +// LoadOrStoreFn returns the existing value for the key if +// present. Otherwise, it tries to compute the value using the +// provided function and, if successful, stores and returns +// the computed value. 
The loaded result is true if the value was +// loaded, or false if computed. +// +// This call locks a hash table bucket while the compute function +// is executed. It means that modifications on other entries in +// the bucket will be blocked until the valueFn executes. Consider +// this when the function includes long-running operations. +// +// Recovery this API and renamed from xsync/v3's LoadOrCompute. +// We unneeded support no-op (cancel) compute operation, it will only add complexity to existing code. +func (m *Map[K, V]) LoadOrStoreFn(key K, valueFn func() V) (actual V, loaded bool) { + return m.doCompute( + key, + func(oldValue V, loaded bool) (V, ComputeOp) { + if loaded { + return oldValue, CancelOp + } + return valueFn(), UpdateOp + }, + loadOrComputeOp, + false, + ) +} diff --git a/common/xsync/map_extra_test.go b/common/xsync/map_extra_test.go new file mode 100644 index 00000000..b8938b8b --- /dev/null +++ b/common/xsync/map_extra_test.go @@ -0,0 +1,49 @@ +package xsync + +import ( + "strconv" + "testing" +) + +func TestMapOfLoadOrStoreFn(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + v, loaded := m.LoadOrStoreFn(strconv.Itoa(i), func() int { + return i + }) + if loaded { + t.Fatalf("value not computed for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } + for i := 0; i < numEntries; i++ { + v, loaded := m.LoadOrStoreFn(strconv.Itoa(i), func() int { + return i + }) + if !loaded { + t.Fatalf("value not loaded for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } +} + +func TestMapOfLoadOrStoreFn_FunctionCalledOnce(t *testing.T) { + m := NewMap[int, int]() + for i := 0; i < 100; { + m.LoadOrStoreFn(i, func() (v int) { + v, i = i, i+1 + return v + }) + } + m.Range(func(k, v int) bool { + if k != v { + t.Fatalf("%dth key is not equal to value %d", k, v) + } + return true + }) +} diff --git a/common/xsync/map_test.go b/common/xsync/map_test.go new file mode 100644 index 00000000..b40d412b --- /dev/null +++ b/common/xsync/map_test.go @@ -0,0 +1,1732 @@ +package xsync + +import ( + "math" + "math/rand" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + "unsafe" + + "github.com/metacubex/randv2" +) + +const ( + // number of entries to use in benchmarks + benchmarkNumEntries = 1_000 + // key prefix used in benchmarks + benchmarkKeyPrefix = "what_a_looooooooooooooooooooooong_key_prefix_" +) + +type point struct { + x int32 + y int32 +} + +var benchmarkCases = []struct { + name string + readPercentage int +}{ + {"reads=100%", 100}, // 100% loads, 0% stores, 0% deletes + {"reads=99%", 99}, // 99% loads, 0.5% stores, 0.5% deletes + {"reads=90%", 90}, // 90% loads, 5% stores, 5% deletes + {"reads=75%", 75}, // 75% loads, 12.5% stores, 12.5% deletes +} + +var benchmarkKeys []string + +func init() { + benchmarkKeys = make([]string, benchmarkNumEntries) + for i := 0; i < benchmarkNumEntries; i++ { + benchmarkKeys[i] = benchmarkKeyPrefix + strconv.Itoa(i) + } +} + +func runParallel(b *testing.B, benchFn func(pb *testing.PB)) { + b.ResetTimer() + start := time.Now() + b.RunParallel(benchFn) + opsPerSec := float64(b.N) / float64(time.Since(start).Seconds()) + b.ReportMetric(opsPerSec, "ops/s") +} + +func TestMap_BucketStructSize(t *testing.T) { + size := unsafe.Sizeof(bucketPadded[string, int64]{}) + if size != 64 { + t.Fatalf("size of 64B (one cache line) is expected, got: %d", size) + } + size = unsafe.Sizeof(bucketPadded[struct{}, int32]{}) + if size 
!= 64 { + t.Fatalf("size of 64B (one cache line) is expected, got: %d", size) + } +} + +func TestMap_MissingEntry(t *testing.T) { + m := NewMap[string, string]() + v, ok := m.Load("foo") + if ok { + t.Fatalf("value was not expected: %v", v) + } + if deleted, loaded := m.LoadAndDelete("foo"); loaded { + t.Fatalf("value was not expected %v", deleted) + } + if actual, loaded := m.LoadOrStore("foo", "bar"); loaded { + t.Fatalf("value was not expected %v", actual) + } +} + +func TestMap_EmptyStringKey(t *testing.T) { + m := NewMap[string, string]() + m.Store("", "foobar") + v, ok := m.Load("") + if !ok { + t.Fatal("value was expected") + } + if v != "foobar" { + t.Fatalf("value does not match: %v", v) + } +} + +func TestMapStore_NilValue(t *testing.T) { + m := NewMap[string, *struct{}]() + m.Store("foo", nil) + v, ok := m.Load("foo") + if !ok { + t.Fatal("nil value was expected") + } + if v != nil { + t.Fatalf("value was not nil: %v", v) + } +} + +func TestMapLoadOrStore_NilValue(t *testing.T) { + m := NewMap[string, *struct{}]() + m.LoadOrStore("foo", nil) + v, loaded := m.LoadOrStore("foo", nil) + if !loaded { + t.Fatal("nil value was expected") + } + if v != nil { + t.Fatalf("value was not nil: %v", v) + } +} + +func TestMapLoadOrStore_NonNilValue(t *testing.T) { + type foo struct{} + m := NewMap[string, *foo]() + newv := &foo{} + v, loaded := m.LoadOrStore("foo", newv) + if loaded { + t.Fatal("no value was expected") + } + if v != newv { + t.Fatalf("value does not match: %v", v) + } + newv2 := &foo{} + v, loaded = m.LoadOrStore("foo", newv2) + if !loaded { + t.Fatal("value was expected") + } + if v != newv { + t.Fatalf("value does not match: %v", v) + } +} + +func TestMapLoadAndStore_NilValue(t *testing.T) { + m := NewMap[string, *struct{}]() + m.LoadAndStore("foo", nil) + v, loaded := m.LoadAndStore("foo", nil) + if !loaded { + t.Fatal("nil value was expected") + } + if v != nil { + t.Fatalf("value was not nil: %v", v) + } + v, loaded = m.Load("foo") + if !loaded { + t.Fatal("nil value was expected") + } + if v != nil { + t.Fatalf("value was not nil: %v", v) + } +} + +func TestMapLoadAndStore_NonNilValue(t *testing.T) { + m := NewMap[string, int]() + v1 := 1 + v, loaded := m.LoadAndStore("foo", v1) + if loaded { + t.Fatal("no value was expected") + } + if v != v1 { + t.Fatalf("value does not match: %v", v) + } + v2 := 2 + v, loaded = m.LoadAndStore("foo", v2) + if !loaded { + t.Fatal("value was expected") + } + if v != v1 { + t.Fatalf("value does not match: %v", v) + } + v, loaded = m.Load("foo") + if !loaded { + t.Fatal("value was expected") + } + if v != v2 { + t.Fatalf("value does not match: %v", v) + } +} + +func TestMapRange(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + iters := 0 + met := make(map[string]int) + m.Range(func(key string, value int) bool { + if key != strconv.Itoa(value) { + t.Fatalf("got unexpected key/value for iteration %d: %v/%v", iters, key, value) + return false + } + met[key] += 1 + iters++ + return true + }) + if iters != numEntries { + t.Fatalf("got unexpected number of iterations: %d", iters) + } + for i := 0; i < numEntries; i++ { + if c := met[strconv.Itoa(i)]; c != 1 { + t.Fatalf("range did not iterate correctly over %d: %d", i, c) + } + } +} + +func TestMapRange_FalseReturned(t *testing.T) { + m := NewMap[string, int]() + for i := 0; i < 100; i++ { + m.Store(strconv.Itoa(i), i) + } + iters := 0 + m.Range(func(key string, value int) bool { + iters++ + return 
iters != 13 + }) + if iters != 13 { + t.Fatalf("got unexpected number of iterations: %d", iters) + } +} + +func TestMapRange_NestedDelete(t *testing.T) { + const numEntries = 256 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + m.Range(func(key string, value int) bool { + m.Delete(key) + return true + }) + for i := 0; i < numEntries; i++ { + if _, ok := m.Load(strconv.Itoa(i)); ok { + t.Fatalf("value found for %d", i) + } + } +} + +func TestMapStringStore(t *testing.T) { + const numEntries = 128 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + for i := 0; i < numEntries; i++ { + v, ok := m.Load(strconv.Itoa(i)) + if !ok { + t.Fatalf("value not found for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } +} + +func TestMapIntStore(t *testing.T) { + const numEntries = 128 + m := NewMap[int, int]() + for i := 0; i < numEntries; i++ { + m.Store(i, i) + } + for i := 0; i < numEntries; i++ { + v, ok := m.Load(i) + if !ok { + t.Fatalf("value not found for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } +} + +func TestMapStore_StructKeys_IntValues(t *testing.T) { + const numEntries = 128 + m := NewMap[point, int]() + for i := 0; i < numEntries; i++ { + m.Store(point{int32(i), -int32(i)}, i) + } + for i := 0; i < numEntries; i++ { + v, ok := m.Load(point{int32(i), -int32(i)}) + if !ok { + t.Fatalf("value not found for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } +} + +func TestMapStore_StructKeys_StructValues(t *testing.T) { + const numEntries = 128 + m := NewMap[point, point]() + for i := 0; i < numEntries; i++ { + m.Store(point{int32(i), -int32(i)}, point{-int32(i), int32(i)}) + } + for i := 0; i < numEntries; i++ { + v, ok := m.Load(point{int32(i), -int32(i)}) + if !ok { + t.Fatalf("value not found for %d", i) + } + if v.x != -int32(i) { + t.Fatalf("x value does not match for %d: %v", i, v) + } + if v.y != int32(i) { + t.Fatalf("y value does not match for %d: %v", i, v) + } + } +} + +func TestMapLoadOrStore(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + for i := 0; i < numEntries; i++ { + if _, loaded := m.LoadOrStore(strconv.Itoa(i), i); !loaded { + t.Fatalf("value not found for %d", i) + } + } +} + +func TestMapLoadOrCompute(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + v, loaded := m.LoadOrCompute(strconv.Itoa(i), func() (newValue int, cancel bool) { + return i, true + }) + if loaded { + t.Fatalf("value not computed for %d", i) + } + if v != 0 { + t.Fatalf("values do not match for %d: %v", i, v) + } + } + if m.Size() != 0 { + t.Fatalf("zero map size expected: %d", m.Size()) + } + for i := 0; i < numEntries; i++ { + v, loaded := m.LoadOrCompute(strconv.Itoa(i), func() (newValue int, cancel bool) { + return i, false + }) + if loaded { + t.Fatalf("value not computed for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } + for i := 0; i < numEntries; i++ { + v, loaded := m.LoadOrCompute(strconv.Itoa(i), func() (newValue int, cancel bool) { + t.Fatalf("value func invoked") + return newValue, false + }) + if !loaded { + t.Fatalf("value not loaded for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } +} + +func TestMapLoadOrCompute_FunctionCalledOnce(t *testing.T) 
{ + m := NewMap[int, int]() + for i := 0; i < 100; { + m.LoadOrCompute(i, func() (newValue int, cancel bool) { + newValue, i = i, i+1 + return newValue, false + }) + } + m.Range(func(k, v int) bool { + if k != v { + t.Fatalf("%dth key is not equal to value %d", k, v) + } + return true + }) +} + +func TestMapOfCompute(t *testing.T) { + m := NewMap[string, int]() + // Store a new value. + v, ok := m.Compute("foobar", func(oldValue int, loaded bool) (newValue int, op ComputeOp) { + if oldValue != 0 { + t.Fatalf("oldValue should be 0 when computing a new value: %d", oldValue) + } + if loaded { + t.Fatal("loaded should be false when computing a new value") + } + newValue = 42 + op = UpdateOp + return + }) + if v != 42 { + t.Fatalf("v should be 42 when computing a new value: %d", v) + } + if !ok { + t.Fatal("ok should be true when computing a new value") + } + // Update an existing value. + v, ok = m.Compute("foobar", func(oldValue int, loaded bool) (newValue int, op ComputeOp) { + if oldValue != 42 { + t.Fatalf("oldValue should be 42 when updating the value: %d", oldValue) + } + if !loaded { + t.Fatal("loaded should be true when updating the value") + } + newValue = oldValue + 42 + op = UpdateOp + return + }) + if v != 84 { + t.Fatalf("v should be 84 when updating the value: %d", v) + } + if !ok { + t.Fatal("ok should be true when updating the value") + } + // Check that NoOp doesn't update the value + v, ok = m.Compute("foobar", func(oldValue int, loaded bool) (newValue int, op ComputeOp) { + return 0, CancelOp + }) + if v != 84 { + t.Fatalf("v should be 84 after using NoOp: %d", v) + } + if !ok { + t.Fatal("ok should be true when updating the value") + } + // Delete an existing value. + v, ok = m.Compute("foobar", func(oldValue int, loaded bool) (newValue int, op ComputeOp) { + if oldValue != 84 { + t.Fatalf("oldValue should be 84 when deleting the value: %d", oldValue) + } + if !loaded { + t.Fatal("loaded should be true when deleting the value") + } + op = DeleteOp + return + }) + if v != 84 { + t.Fatalf("v should be 84 when deleting the value: %d", v) + } + if ok { + t.Fatal("ok should be false when deleting the value") + } + // Try to delete a non-existing value. Notice different key. + v, ok = m.Compute("barbaz", func(oldValue int, loaded bool) (newValue int, op ComputeOp) { + if oldValue != 0 { + t.Fatalf("oldValue should be 0 when trying to delete a non-existing value: %d", oldValue) + } + if loaded { + t.Fatal("loaded should be false when trying to delete a non-existing value") + } + // We're returning a non-zero value, but the map should ignore it. + newValue = 42 + op = DeleteOp + return + }) + if v != 0 { + t.Fatalf("v should be 0 when trying to delete a non-existing value: %d", v) + } + if ok { + t.Fatal("ok should be false when trying to delete a non-existing value") + } + // Try NoOp on a non-existing value + v, ok = m.Compute("barbaz", func(oldValue int, loaded bool) (newValue int, op ComputeOp) { + if oldValue != 0 { + t.Fatalf("oldValue should be 0 when trying to delete a non-existing value: %d", oldValue) + } + if loaded { + t.Fatal("loaded should be false when trying to delete a non-existing value") + } + // We're returning a non-zero value, but the map should ignore it. 
+ newValue = 42 + op = CancelOp + return + }) + if v != 0 { + t.Fatalf("v should be 0 when trying to delete a non-existing value: %d", v) + } + if ok { + t.Fatal("ok should be false when trying to delete a non-existing value") + } +} + +func TestMapStringStoreThenDelete(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + for i := 0; i < numEntries; i++ { + m.Delete(strconv.Itoa(i)) + if _, ok := m.Load(strconv.Itoa(i)); ok { + t.Fatalf("value was not expected for %d", i) + } + } +} + +func TestMapIntStoreThenDelete(t *testing.T) { + const numEntries = 1000 + m := NewMap[int32, int32]() + for i := 0; i < numEntries; i++ { + m.Store(int32(i), int32(i)) + } + for i := 0; i < numEntries; i++ { + m.Delete(int32(i)) + if _, ok := m.Load(int32(i)); ok { + t.Fatalf("value was not expected for %d", i) + } + } +} + +func TestMapStructStoreThenDelete(t *testing.T) { + const numEntries = 1000 + m := NewMap[point, string]() + for i := 0; i < numEntries; i++ { + m.Store(point{int32(i), 42}, strconv.Itoa(i)) + } + for i := 0; i < numEntries; i++ { + m.Delete(point{int32(i), 42}) + if _, ok := m.Load(point{int32(i), 42}); ok { + t.Fatalf("value was not expected for %d", i) + } + } +} + +func TestMapStringStoreThenLoadAndDelete(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + for i := 0; i < numEntries; i++ { + if v, loaded := m.LoadAndDelete(strconv.Itoa(i)); !loaded || v != i { + t.Fatalf("value was not found or different for %d: %v", i, v) + } + if _, ok := m.Load(strconv.Itoa(i)); ok { + t.Fatalf("value was not expected for %d", i) + } + } +} + +func TestMapIntStoreThenLoadAndDelete(t *testing.T) { + const numEntries = 1000 + m := NewMap[int, int]() + for i := 0; i < numEntries; i++ { + m.Store(i, i) + } + for i := 0; i < numEntries; i++ { + if _, loaded := m.LoadAndDelete(i); !loaded { + t.Fatalf("value was not found for %d", i) + } + if _, ok := m.Load(i); ok { + t.Fatalf("value was not expected for %d", i) + } + } +} + +func TestMapStructStoreThenLoadAndDelete(t *testing.T) { + const numEntries = 1000 + m := NewMap[point, int]() + for i := 0; i < numEntries; i++ { + m.Store(point{42, int32(i)}, i) + } + for i := 0; i < numEntries; i++ { + if _, loaded := m.LoadAndDelete(point{42, int32(i)}); !loaded { + t.Fatalf("value was not found for %d", i) + } + if _, ok := m.Load(point{42, int32(i)}); ok { + t.Fatalf("value was not expected for %d", i) + } + } +} + +func TestMapStoreThenParallelDelete_DoesNotShrinkBelowMinTableLen(t *testing.T) { + const numEntries = 1000 + m := NewMap[int, int]() + for i := 0; i < numEntries; i++ { + m.Store(i, i) + } + + cdone := make(chan bool) + go func() { + for i := 0; i < numEntries; i++ { + m.Delete(i) + } + cdone <- true + }() + go func() { + for i := 0; i < numEntries; i++ { + m.Delete(i) + } + cdone <- true + }() + + // Wait for the goroutines to finish. 
+ <-cdone + <-cdone + + stats := m.Stats() + if stats.RootBuckets != defaultMinMapTableLen { + t.Fatalf("table length was different from the minimum: %d", stats.RootBuckets) + } +} + +func sizeBasedOnTypedRange(m *Map[string, int]) int { + size := 0 + m.Range(func(key string, value int) bool { + size++ + return true + }) + return size +} + +func TestMapSize(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + size := m.Size() + if size != 0 { + t.Fatalf("zero size expected: %d", size) + } + expectedSize := 0 + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + expectedSize++ + size := m.Size() + if size != expectedSize { + t.Fatalf("size of %d was expected, got: %d", expectedSize, size) + } + rsize := sizeBasedOnTypedRange(m) + if size != rsize { + t.Fatalf("size does not match number of entries in Range: %v, %v", size, rsize) + } + } + for i := 0; i < numEntries; i++ { + m.Delete(strconv.Itoa(i)) + expectedSize-- + size := m.Size() + if size != expectedSize { + t.Fatalf("size of %d was expected, got: %d", expectedSize, size) + } + rsize := sizeBasedOnTypedRange(m) + if size != rsize { + t.Fatalf("size does not match number of entries in Range: %v, %v", size, rsize) + } + } +} + +func TestMapClear(t *testing.T) { + const numEntries = 1000 + m := NewMap[string, int]() + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + size := m.Size() + if size != numEntries { + t.Fatalf("size of %d was expected, got: %d", numEntries, size) + } + m.Clear() + size = m.Size() + if size != 0 { + t.Fatalf("zero size was expected, got: %d", size) + } + rsize := sizeBasedOnTypedRange(m) + if rsize != 0 { + t.Fatalf("zero number of entries in Range was expected, got: %d", rsize) + } +} + +func assertMapCapacity[K comparable, V any](t *testing.T, m *Map[K, V], expectedCap int) { + stats := m.Stats() + if stats.Capacity != expectedCap { + t.Fatalf("capacity was different from %d: %d", expectedCap, stats.Capacity) + } +} + +func TestNewMapWithPresize(t *testing.T) { + assertMapCapacity(t, NewMap[string, string](), defaultMinMapTableLen*entriesPerMapBucket) + assertMapCapacity(t, NewMap[string, string](WithPresize(0)), defaultMinMapTableLen*entriesPerMapBucket) + assertMapCapacity(t, NewMap[string, string](WithPresize(-100)), defaultMinMapTableLen*entriesPerMapBucket) + assertMapCapacity(t, NewMap[string, string](WithPresize(500)), 1280) + assertMapCapacity(t, NewMap[int, int](WithPresize(1_000_000)), 2621440) + assertMapCapacity(t, NewMap[point, point](WithPresize(100)), 160) +} + +func TestNewMapWithPresize_DoesNotShrinkBelowMinTableLen(t *testing.T) { + const minTableLen = 1024 + const numEntries = int(minTableLen * entriesPerMapBucket * mapLoadFactor) + m := NewMap[int, int](WithPresize(numEntries)) + for i := 0; i < 2*numEntries; i++ { + m.Store(i, i) + } + + stats := m.Stats() + if stats.RootBuckets <= minTableLen { + t.Fatalf("table did not grow: %d", stats.RootBuckets) + } + + for i := 0; i < 2*numEntries; i++ { + m.Delete(i) + } + + stats = m.Stats() + if stats.RootBuckets != minTableLen { + t.Fatalf("table length was different from the minimum: %d", stats.RootBuckets) + } +} + +func TestNewMapGrowOnly_OnlyShrinksOnClear(t *testing.T) { + const minTableLen = 128 + const numEntries = minTableLen * entriesPerMapBucket + m := NewMap[int, int](WithPresize(numEntries), WithGrowOnly()) + + stats := m.Stats() + initialTableLen := stats.RootBuckets + + for i := 0; i < 2*numEntries; i++ { + m.Store(i, i) + } + stats = m.Stats() + maxTableLen := stats.RootBuckets 
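+ // The presized, grow-only table must have grown after storing 2*numEntries items.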
+ if maxTableLen <= minTableLen { + t.Fatalf("table did not grow: %d", maxTableLen) + } + + for i := 0; i < numEntries; i++ { + m.Delete(i) + } + stats = m.Stats() + if stats.RootBuckets != maxTableLen { + t.Fatalf("table length was different from the expected: %d", stats.RootBuckets) + } + + m.Clear() + stats = m.Stats() + if stats.RootBuckets != initialTableLen { + t.Fatalf("table length was different from the initial: %d", stats.RootBuckets) + } +} + +func TestMapResize(t *testing.T) { + testMapResize(t, NewMap[string, int]()) +} + +func testMapResize(t *testing.T, m *Map[string, int]) { + const numEntries = 100_000 + + for i := 0; i < numEntries; i++ { + m.Store(strconv.Itoa(i), i) + } + stats := m.Stats() + if stats.Size != numEntries { + t.Fatalf("size was too small: %d", stats.Size) + } + expectedCapacity := int(math.RoundToEven(mapLoadFactor+1)) * stats.RootBuckets * entriesPerMapBucket + if stats.Capacity > expectedCapacity { + t.Fatalf("capacity was too large: %d, expected: %d", stats.Capacity, expectedCapacity) + } + if stats.RootBuckets <= defaultMinMapTableLen { + t.Fatalf("table was too small: %d", stats.RootBuckets) + } + if stats.TotalGrowths == 0 { + t.Fatalf("non-zero total growths expected: %d", stats.TotalGrowths) + } + if stats.TotalShrinks > 0 { + t.Fatalf("zero total shrinks expected: %d", stats.TotalShrinks) + } + // This is useful when debugging table resize and occupancy. + // Use -v flag to see the output. + t.Log(stats.ToString()) + + for i := 0; i < numEntries; i++ { + m.Delete(strconv.Itoa(i)) + } + stats = m.Stats() + if stats.Size > 0 { + t.Fatalf("zero size was expected: %d", stats.Size) + } + expectedCapacity = stats.RootBuckets * entriesPerMapBucket + if stats.Capacity != expectedCapacity { + t.Fatalf("capacity was too large: %d, expected: %d", stats.Capacity, expectedCapacity) + } + if stats.RootBuckets != defaultMinMapTableLen { + t.Fatalf("table was too large: %d", stats.RootBuckets) + } + if stats.TotalShrinks == 0 { + t.Fatalf("non-zero total shrinks expected: %d", stats.TotalShrinks) + } + t.Log(stats.ToString()) +} + +func TestMapResize_CounterLenLimit(t *testing.T) { + const numEntries = 1_000_000 + m := NewMap[string, string]() + + for i := 0; i < numEntries; i++ { + m.Store("foo"+strconv.Itoa(i), "bar"+strconv.Itoa(i)) + } + stats := m.Stats() + if stats.Size != numEntries { + t.Fatalf("size was too small: %d", stats.Size) + } + if stats.CounterLen != maxMapCounterLen { + t.Fatalf("number of counter stripes was too large: %d, expected: %d", + stats.CounterLen, maxMapCounterLen) + } +} + +func parallelSeqMapGrower(m *Map[int, int], numEntries int, positive bool, cdone chan bool) { + for i := 0; i < numEntries; i++ { + if positive { + m.Store(i, i) + } else { + m.Store(-i, -i) + } + } + cdone <- true +} + +func TestMapParallelGrowth_GrowOnly(t *testing.T) { + const numEntries = 100_000 + m := NewMap[int, int]() + cdone := make(chan bool) + go parallelSeqMapGrower(m, numEntries, true, cdone) + go parallelSeqMapGrower(m, numEntries, false, cdone) + // Wait for the goroutines to finish. + <-cdone + <-cdone + // Verify map contents. 
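+ // The two growers cover keys -(numEntries-1) through numEntries-1 and overlap only at zero, so 2*numEntries-1 entries are expected.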
+ for i := -numEntries + 1; i < numEntries; i++ { + v, ok := m.Load(i) + if !ok { + t.Fatalf("value not found for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } + if s := m.Size(); s != 2*numEntries-1 { + t.Fatalf("unexpected size: %v", s) + } +} + +func parallelRandMapResizer(t *testing.T, m *Map[string, int], numIters, numEntries int, cdone chan bool) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < numIters; i++ { + coin := r.Int63n(2) + for j := 0; j < numEntries; j++ { + if coin == 1 { + m.Store(strconv.Itoa(j), j) + } else { + m.Delete(strconv.Itoa(j)) + } + } + } + cdone <- true +} + +func TestMapParallelGrowth(t *testing.T) { + const numIters = 1_000 + const numEntries = 2 * entriesPerMapBucket * defaultMinMapTableLen + m := NewMap[string, int]() + cdone := make(chan bool) + go parallelRandMapResizer(t, m, numIters, numEntries, cdone) + go parallelRandMapResizer(t, m, numIters, numEntries, cdone) + // Wait for the goroutines to finish. + <-cdone + <-cdone + // Verify map contents. + for i := 0; i < numEntries; i++ { + v, ok := m.Load(strconv.Itoa(i)) + if !ok { + // The entry may be deleted and that's ok. + continue + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } + s := m.Size() + if s > numEntries { + t.Fatalf("unexpected size: %v", s) + } + rs := sizeBasedOnTypedRange(m) + if s != rs { + t.Fatalf("size does not match number of entries in Range: %v, %v", s, rs) + } +} + +func parallelRandMapClearer(t *testing.T, m *Map[string, int], numIters, numEntries int, cdone chan bool) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < numIters; i++ { + coin := r.Int63n(2) + for j := 0; j < numEntries; j++ { + if coin == 1 { + m.Store(strconv.Itoa(j), j) + } else { + m.Clear() + } + } + } + cdone <- true +} + +func TestMapParallelClear(t *testing.T) { + const numIters = 100 + const numEntries = 1_000 + m := NewMap[string, int]() + cdone := make(chan bool) + go parallelRandMapClearer(t, m, numIters, numEntries, cdone) + go parallelRandMapClearer(t, m, numIters, numEntries, cdone) + // Wait for the goroutines to finish. + <-cdone + <-cdone + // Verify map size. + s := m.Size() + if s > numEntries { + t.Fatalf("unexpected size: %v", s) + } + rs := sizeBasedOnTypedRange(m) + if s != rs { + t.Fatalf("size does not match number of entries in Range: %v, %v", s, rs) + } +} + +func parallelSeqMapStorer(t *testing.T, m *Map[string, int], storeEach, numIters, numEntries int, cdone chan bool) { + for i := 0; i < numIters; i++ { + for j := 0; j < numEntries; j++ { + if storeEach == 0 || j%storeEach == 0 { + m.Store(strconv.Itoa(j), j) + // Due to atomic snapshots we must see a ""/j pair. + v, ok := m.Load(strconv.Itoa(j)) + if !ok { + t.Errorf("value was not found for %d", j) + break + } + if v != j { + t.Errorf("value was not expected for %d: %d", j, v) + break + } + } + } + } + cdone <- true +} + +func TestMapParallelStores(t *testing.T) { + const numStorers = 4 + const numIters = 10_000 + const numEntries = 100 + m := NewMap[string, int]() + cdone := make(chan bool) + for i := 0; i < numStorers; i++ { + go parallelSeqMapStorer(t, m, i, numIters, numEntries, cdone) + } + // Wait for the goroutines to finish. + for i := 0; i < numStorers; i++ { + <-cdone + } + // Verify map contents. 
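+ // The storers with storeEach 0 and 1 write every key, so all numEntries entries must be present.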
+ for i := 0; i < numEntries; i++ { + v, ok := m.Load(strconv.Itoa(i)) + if !ok { + t.Fatalf("value not found for %d", i) + } + if v != i { + t.Fatalf("values do not match for %d: %v", i, v) + } + } +} + +func parallelRandMapStorer(t *testing.T, m *Map[string, int], numIters, numEntries int, cdone chan bool) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < numIters; i++ { + j := r.Intn(numEntries) + if v, loaded := m.LoadOrStore(strconv.Itoa(j), j); loaded { + if v != j { + t.Errorf("value was not expected for %d: %d", j, v) + } + } + } + cdone <- true +} + +func parallelRandMapDeleter(t *testing.T, m *Map[string, int], numIters, numEntries int, cdone chan bool) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < numIters; i++ { + j := r.Intn(numEntries) + if v, loaded := m.LoadAndDelete(strconv.Itoa(j)); loaded { + if v != j { + t.Errorf("value was not expected for %d: %d", j, v) + } + } + } + cdone <- true +} + +func parallelMapLoader(t *testing.T, m *Map[string, int], numIters, numEntries int, cdone chan bool) { + for i := 0; i < numIters; i++ { + for j := 0; j < numEntries; j++ { + // Due to atomic snapshots we must either see no entry, or a ""/j pair. + if v, ok := m.Load(strconv.Itoa(j)); ok { + if v != j { + t.Errorf("value was not expected for %d: %d", j, v) + } + } + } + } + cdone <- true +} + +func TestMapAtomicSnapshot(t *testing.T) { + const numIters = 100_000 + const numEntries = 100 + m := NewMap[string, int]() + cdone := make(chan bool) + // Update or delete random entry in parallel with loads. + go parallelRandMapStorer(t, m, numIters, numEntries, cdone) + go parallelRandMapDeleter(t, m, numIters, numEntries, cdone) + go parallelMapLoader(t, m, numIters, numEntries, cdone) + // Wait for the goroutines to finish. + for i := 0; i < 3; i++ { + <-cdone + } +} + +func TestMapParallelStoresAndDeletes(t *testing.T) { + const numWorkers = 2 + const numIters = 100_000 + const numEntries = 1000 + m := NewMap[string, int]() + cdone := make(chan bool) + // Update random entry in parallel with deletes. + for i := 0; i < numWorkers; i++ { + go parallelRandMapStorer(t, m, numIters, numEntries, cdone) + go parallelRandMapDeleter(t, m, numIters, numEntries, cdone) + } + // Wait for the goroutines to finish. + for i := 0; i < 2*numWorkers; i++ { + <-cdone + } +} + +func parallelMapComputer(m *Map[uint64, uint64], numIters, numEntries int, cdone chan bool) { + for i := 0; i < numIters; i++ { + for j := 0; j < numEntries; j++ { + m.Compute(uint64(j), func(oldValue uint64, loaded bool) (newValue uint64, op ComputeOp) { + return oldValue + 1, UpdateOp + }) + } + } + cdone <- true +} + +func TestMapParallelComputes(t *testing.T) { + const numWorkers = 4 // Also stands for numEntries. + const numIters = 10_000 + m := NewMap[uint64, uint64]() + cdone := make(chan bool) + for i := 0; i < numWorkers; i++ { + go parallelMapComputer(m, numIters, numWorkers, cdone) + } + // Wait for the goroutines to finish. + for i := 0; i < numWorkers; i++ { + <-cdone + } + // Verify map contents. 
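+ // Each of the numWorkers goroutines increments every key numIters times, so each value must equal numWorkers*numIters.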
+ for i := 0; i < numWorkers; i++ { + v, ok := m.Load(uint64(i)) + if !ok { + t.Fatalf("value not found for %d", i) + } + if v != numWorkers*numIters { + t.Fatalf("values do not match for %d: %v", i, v) + } + } +} + +func parallelRangeMapStorer(m *Map[int, int], numEntries int, stopFlag *int64, cdone chan bool) { + for { + for i := 0; i < numEntries; i++ { + m.Store(i, i) + } + if atomic.LoadInt64(stopFlag) != 0 { + break + } + } + cdone <- true +} + +func parallelRangeMapDeleter(m *Map[int, int], numEntries int, stopFlag *int64, cdone chan bool) { + for { + for i := 0; i < numEntries; i++ { + m.Delete(i) + } + if atomic.LoadInt64(stopFlag) != 0 { + break + } + } + cdone <- true +} + +func TestMapParallelRange(t *testing.T) { + const numEntries = 10_000 + m := NewMap[int, int](WithPresize(numEntries)) + for i := 0; i < numEntries; i++ { + m.Store(i, i) + } + // Start goroutines that would be storing and deleting items in parallel. + cdone := make(chan bool) + stopFlag := int64(0) + go parallelRangeMapStorer(m, numEntries, &stopFlag, cdone) + go parallelRangeMapDeleter(m, numEntries, &stopFlag, cdone) + // Iterate the map and verify that no duplicate keys were met. + met := make(map[int]int) + m.Range(func(key int, value int) bool { + if key != value { + t.Fatalf("got unexpected value for key %d: %d", key, value) + return false + } + met[key] += 1 + return true + }) + if len(met) == 0 { + t.Fatal("no entries were met when iterating") + } + for k, c := range met { + if c != 1 { + t.Fatalf("met key %d multiple times: %d", k, c) + } + } + // Make sure that both goroutines finish. + atomic.StoreInt64(&stopFlag, 1) + <-cdone + <-cdone +} + +func parallelMapShrinker(t *testing.T, m *Map[uint64, *point], numIters, numEntries int, stopFlag *int64, cdone chan bool) { + for i := 0; i < numIters; i++ { + for j := 0; j < numEntries; j++ { + if p, loaded := m.LoadOrStore(uint64(j), &point{int32(j), int32(j)}); loaded { + t.Errorf("value was present for %d: %v", j, p) + } + } + for j := 0; j < numEntries; j++ { + m.Delete(uint64(j)) + } + } + atomic.StoreInt64(stopFlag, 1) + cdone <- true +} + +func parallelMapUpdater(t *testing.T, m *Map[uint64, *point], idx int, stopFlag *int64, cdone chan bool) { + for atomic.LoadInt64(stopFlag) != 1 { + sleepUs := int(randv2.Uint64() % 10) + if p, loaded := m.LoadOrStore(uint64(idx), &point{int32(idx), int32(idx)}); loaded { + t.Errorf("value was present for %d: %v", idx, p) + } + time.Sleep(time.Duration(sleepUs) * time.Microsecond) + if _, ok := m.Load(uint64(idx)); !ok { + t.Errorf("value was not found for %d", idx) + } + m.Delete(uint64(idx)) + } + cdone <- true +} + +func TestMapDoesNotLoseEntriesOnResize(t *testing.T) { + const numIters = 10_000 + const numEntries = 128 + m := NewMap[uint64, *point]() + cdone := make(chan bool) + stopFlag := int64(0) + go parallelMapShrinker(t, m, numIters, numEntries, &stopFlag, cdone) + go parallelMapUpdater(t, m, numEntries, &stopFlag, cdone) + // Wait for the goroutines to finish. + <-cdone + <-cdone + // Verify map contents. 
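+ // Both goroutines delete every entry they store, so no entries may survive the concurrent resizes.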
+ if s := m.Size(); s != 0 { + t.Fatalf("map is not empty: %d", s) + } +} + +func TestMapStats(t *testing.T) { + m := NewMap[int, int]() + + stats := m.Stats() + if stats.RootBuckets != defaultMinMapTableLen { + t.Fatalf("unexpected number of root buckets: %d", stats.RootBuckets) + } + if stats.TotalBuckets != stats.RootBuckets { + t.Fatalf("unexpected number of total buckets: %d", stats.TotalBuckets) + } + if stats.EmptyBuckets != stats.RootBuckets { + t.Fatalf("unexpected number of empty buckets: %d", stats.EmptyBuckets) + } + if stats.Capacity != entriesPerMapBucket*defaultMinMapTableLen { + t.Fatalf("unexpected capacity: %d", stats.Capacity) + } + if stats.Size != 0 { + t.Fatalf("unexpected size: %d", stats.Size) + } + if stats.Counter != 0 { + t.Fatalf("unexpected counter: %d", stats.Counter) + } + if stats.CounterLen != 8 { + t.Fatalf("unexpected counter length: %d", stats.CounterLen) + } + + for i := 0; i < 200; i++ { + m.Store(i, i) + } + + stats = m.Stats() + if stats.RootBuckets != 2*defaultMinMapTableLen { + t.Fatalf("unexpected number of root buckets: %d", stats.RootBuckets) + } + if stats.TotalBuckets < stats.RootBuckets { + t.Fatalf("unexpected number of total buckets: %d", stats.TotalBuckets) + } + if stats.EmptyBuckets >= stats.RootBuckets { + t.Fatalf("unexpected number of empty buckets: %d", stats.EmptyBuckets) + } + if stats.Capacity < 2*entriesPerMapBucket*defaultMinMapTableLen { + t.Fatalf("unexpected capacity: %d", stats.Capacity) + } + if stats.Size != 200 { + t.Fatalf("unexpected size: %d", stats.Size) + } + if stats.Counter != 200 { + t.Fatalf("unexpected counter: %d", stats.Counter) + } + if stats.CounterLen != 8 { + t.Fatalf("unexpected counter length: %d", stats.CounterLen) + } +} + +func TestToPlainMap_NilPointer(t *testing.T) { + pm := ToPlainMap[int, int](nil) + if len(pm) != 0 { + t.Fatalf("got unexpected size of nil map copy: %d", len(pm)) + } +} + +func TestToPlainMap(t *testing.T) { + const numEntries = 1000 + m := NewMap[int, int]() + for i := 0; i < numEntries; i++ { + m.Store(i, i) + } + pm := ToPlainMap[int, int](m) + if len(pm) != numEntries { + t.Fatalf("got unexpected size of nil map copy: %d", len(pm)) + } + for i := 0; i < numEntries; i++ { + if v := pm[i]; v != i { + t.Fatalf("unexpected value for key %d: %d", i, v) + } + } +} + +func BenchmarkMap_NoWarmUp(b *testing.B) { + for _, bc := range benchmarkCases { + if bc.readPercentage == 100 { + // This benchmark doesn't make sense without a warm-up. 
+ continue + } + b.Run(bc.name, func(b *testing.B) { + m := NewMap[string, int]() + benchmarkMapStringKeys(b, func(k string) (int, bool) { + return m.Load(k) + }, func(k string, v int) { + m.Store(k, v) + }, func(k string) { + m.Delete(k) + }, bc.readPercentage) + }) + } +} + +func BenchmarkMap_WarmUp(b *testing.B) { + for _, bc := range benchmarkCases { + b.Run(bc.name, func(b *testing.B) { + m := NewMap[string, int](WithPresize(benchmarkNumEntries)) + for i := 0; i < benchmarkNumEntries; i++ { + m.Store(benchmarkKeyPrefix+strconv.Itoa(i), i) + } + b.ResetTimer() + benchmarkMapStringKeys(b, func(k string) (int, bool) { + return m.Load(k) + }, func(k string, v int) { + m.Store(k, v) + }, func(k string) { + m.Delete(k) + }, bc.readPercentage) + }) + } +} + +func benchmarkMapStringKeys( + b *testing.B, + loadFn func(k string) (int, bool), + storeFn func(k string, v int), + deleteFn func(k string), + readPercentage int, +) { + runParallel(b, func(pb *testing.PB) { + // convert percent to permille to support 99% case + storeThreshold := 10 * readPercentage + deleteThreshold := 10*readPercentage + ((1000 - 10*readPercentage) / 2) + for pb.Next() { + op := int(randv2.Uint64() % 1000) + i := int(randv2.Uint64() % benchmarkNumEntries) + if op >= deleteThreshold { + deleteFn(benchmarkKeys[i]) + } else if op >= storeThreshold { + storeFn(benchmarkKeys[i], i) + } else { + loadFn(benchmarkKeys[i]) + } + } + }) +} + +func BenchmarkMapInt_NoWarmUp(b *testing.B) { + for _, bc := range benchmarkCases { + if bc.readPercentage == 100 { + // This benchmark doesn't make sense without a warm-up. + continue + } + b.Run(bc.name, func(b *testing.B) { + m := NewMap[int, int]() + benchmarkMapIntKeys(b, func(k int) (int, bool) { + return m.Load(k) + }, func(k int, v int) { + m.Store(k, v) + }, func(k int) { + m.Delete(k) + }, bc.readPercentage) + }) + } +} + +func BenchmarkMapInt_WarmUp(b *testing.B) { + for _, bc := range benchmarkCases { + b.Run(bc.name, func(b *testing.B) { + m := NewMap[int, int](WithPresize(benchmarkNumEntries)) + for i := 0; i < benchmarkNumEntries; i++ { + m.Store(i, i) + } + b.ResetTimer() + benchmarkMapIntKeys(b, func(k int) (int, bool) { + return m.Load(k) + }, func(k int, v int) { + m.Store(k, v) + }, func(k int) { + m.Delete(k) + }, bc.readPercentage) + }) + } +} + +func BenchmarkIntMapStandard_NoWarmUp(b *testing.B) { + for _, bc := range benchmarkCases { + if bc.readPercentage == 100 { + // This benchmark doesn't make sense without a warm-up. + continue + } + b.Run(bc.name, func(b *testing.B) { + var m sync.Map + benchmarkMapIntKeys(b, func(k int) (value int, ok bool) { + v, ok := m.Load(k) + if ok { + return v.(int), ok + } else { + return 0, false + } + }, func(k int, v int) { + m.Store(k, v) + }, func(k int) { + m.Delete(k) + }, bc.readPercentage) + }) + } +} + +// This is a nice scenario for sync.Map since a lot of updates +// will hit the readOnly part of the map. 
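+ // With all keys pre-stored, most Store calls hit existing entries and avoid taking the dirty-map lock.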
+func BenchmarkIntMapStandard_WarmUp(b *testing.B) { + for _, bc := range benchmarkCases { + b.Run(bc.name, func(b *testing.B) { + var m sync.Map + for i := 0; i < benchmarkNumEntries; i++ { + m.Store(i, i) + } + b.ResetTimer() + benchmarkMapIntKeys(b, func(k int) (value int, ok bool) { + v, ok := m.Load(k) + if ok { + return v.(int), ok + } else { + return 0, false + } + }, func(k int, v int) { + m.Store(k, v) + }, func(k int) { + m.Delete(k) + }, bc.readPercentage) + }) + } +} + +func benchmarkMapIntKeys( + b *testing.B, + loadFn func(k int) (int, bool), + storeFn func(k int, v int), + deleteFn func(k int), + readPercentage int, +) { + runParallel(b, func(pb *testing.PB) { + // convert percent to permille to support 99% case + storeThreshold := 10 * readPercentage + deleteThreshold := 10*readPercentage + ((1000 - 10*readPercentage) / 2) + for pb.Next() { + op := int(randv2.Uint64() % 1000) + i := int(randv2.Uint64() % benchmarkNumEntries) + if op >= deleteThreshold { + deleteFn(i) + } else if op >= storeThreshold { + storeFn(i, i) + } else { + loadFn(i) + } + } + }) +} + +func BenchmarkMapRange(b *testing.B) { + m := NewMap[string, int](WithPresize(benchmarkNumEntries)) + for i := 0; i < benchmarkNumEntries; i++ { + m.Store(benchmarkKeys[i], i) + } + b.ResetTimer() + runParallel(b, func(pb *testing.PB) { + foo := 0 + for pb.Next() { + m.Range(func(key string, value int) bool { + foo++ + return true + }) + _ = foo + } + }) +} + +// Benchmarks noop performance of Compute +func BenchmarkCompute(b *testing.B) { + tests := []struct { + Name string + Op ComputeOp + }{ + { + Name: "UpdateOp", + Op: UpdateOp, + }, + { + Name: "CancelOp", + Op: CancelOp, + }, + } + for _, test := range tests { + b.Run("op="+test.Name, func(b *testing.B) { + m := NewMap[struct{}, bool]() + m.Store(struct{}{}, true) + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Compute(struct{}{}, func(oldValue bool, loaded bool) (newValue bool, op ComputeOp) { + return oldValue, test.Op + }) + } + }) + } +} + +func TestNextPowOf2(t *testing.T) { + if nextPowOf2(0) != 1 { + t.Error("nextPowOf2 failed") + } + if nextPowOf2(1) != 1 { + t.Error("nextPowOf2 failed") + } + if nextPowOf2(2) != 2 { + t.Error("nextPowOf2 failed") + } + if nextPowOf2(3) != 4 { + t.Error("nextPowOf2 failed") + } +} + +func TestBroadcast(t *testing.T) { + testCases := []struct { + input uint8 + expected uint64 + }{ + { + input: 0, + expected: 0, + }, + { + input: 1, + expected: 0x0101010101010101, + }, + { + input: 2, + expected: 0x0202020202020202, + }, + { + input: 42, + expected: 0x2a2a2a2a2a2a2a2a, + }, + { + input: 127, + expected: 0x7f7f7f7f7f7f7f7f, + }, + { + input: 255, + expected: 0xffffffffffffffff, + }, + } + + for _, tc := range testCases { + t.Run(strconv.Itoa(int(tc.input)), func(t *testing.T) { + if broadcast(tc.input) != tc.expected { + t.Errorf("unexpected result: %x", broadcast(tc.input)) + } + }) + } +} + +func TestFirstMarkedByteIndex(t *testing.T) { + testCases := []struct { + input uint64 + expected int + }{ + { + input: 0, + expected: 8, + }, + { + input: 0x8080808080808080, + expected: 0, + }, + { + input: 0x0000000000000080, + expected: 0, + }, + { + input: 0x0000000000008000, + expected: 1, + }, + { + input: 0x0000000000800000, + expected: 2, + }, + { + input: 0x0000000080000000, + expected: 3, + }, + { + input: 0x0000008000000000, + expected: 4, + }, + { + input: 0x0000800000000000, + expected: 5, + }, + { + input: 0x0080000000000000, + expected: 6, + }, + { + input: 0x8000000000000000, + expected: 7, + }, + } + + for _, tc := 
range testCases { + t.Run(strconv.Itoa(int(tc.input)), func(t *testing.T) { + if firstMarkedByteIndex(tc.input) != tc.expected { + t.Errorf("unexpected result: %x", firstMarkedByteIndex(tc.input)) + } + }) + } +} + +func TestMarkZeroBytes(t *testing.T) { + testCases := []struct { + input uint64 + expected uint64 + }{ + { + input: 0xffffffffffffffff, + expected: 0, + }, + { + input: 0, + expected: 0x8080808080808080, + }, + { + input: 1, + expected: 0x8080808080808000, + }, + { + input: 1 << 9, + expected: 0x8080808080800080, + }, + { + input: 1 << 17, + expected: 0x8080808080008080, + }, + { + input: 1 << 25, + expected: 0x8080808000808080, + }, + { + input: 1 << 33, + expected: 0x8080800080808080, + }, + { + input: 1 << 41, + expected: 0x8080008080808080, + }, + { + input: 1 << 49, + expected: 0x8000808080808080, + }, + { + input: 1 << 57, + expected: 0x0080808080808080, + }, + // false positive + { + input: 0x0100, + expected: 0x8080808080808080, + }, + } + + for _, tc := range testCases { + t.Run(strconv.Itoa(int(tc.input)), func(t *testing.T) { + if markZeroBytes(tc.input) != tc.expected { + t.Errorf("unexpected result: %x", markZeroBytes(tc.input)) + } + }) + } +} + +func TestSetByte(t *testing.T) { + testCases := []struct { + word uint64 + b uint8 + idx int + expected uint64 + }{ + { + word: 0xffffffffffffffff, + b: 0, + idx: 0, + expected: 0xffffffffffffff00, + }, + { + word: 0xffffffffffffffff, + b: 1, + idx: 1, + expected: 0xffffffffffff01ff, + }, + { + word: 0xffffffffffffffff, + b: 2, + idx: 2, + expected: 0xffffffffff02ffff, + }, + { + word: 0xffffffffffffffff, + b: 3, + idx: 3, + expected: 0xffffffff03ffffff, + }, + { + word: 0xffffffffffffffff, + b: 4, + idx: 4, + expected: 0xffffff04ffffffff, + }, + { + word: 0xffffffffffffffff, + b: 5, + idx: 5, + expected: 0xffff05ffffffffff, + }, + { + word: 0xffffffffffffffff, + b: 6, + idx: 6, + expected: 0xff06ffffffffffff, + }, + { + word: 0xffffffffffffffff, + b: 7, + idx: 7, + expected: 0x07ffffffffffffff, + }, + { + word: 0, + b: 0xff, + idx: 7, + expected: 0xff00000000000000, + }, + } + + for _, tc := range testCases { + t.Run(strconv.Itoa(int(tc.word)), func(t *testing.T) { + if setByte(tc.word, tc.b, tc.idx) != tc.expected { + t.Errorf("unexpected result: %x", setByte(tc.word, tc.b, tc.idx)) + } + }) + } +} diff --git a/component/loopback/detector.go b/component/loopback/detector.go index c639ab22..67ffe48a 100644 --- a/component/loopback/detector.go +++ b/component/loopback/detector.go @@ -8,11 +8,10 @@ import ( "strconv" "github.com/metacubex/mihomo/common/callback" + "github.com/metacubex/mihomo/common/xsync" "github.com/metacubex/mihomo/component/iface" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/constant/features" - - "github.com/puzpuzpuz/xsync/v3" ) var disableLoopBackDetector, _ = strconv.ParseBool(os.Getenv("DISABLE_LOOPBACK_DETECTOR")) @@ -26,8 +25,8 @@ func init() { var ErrReject = errors.New("reject loopback connection") type Detector struct { - connMap *xsync.MapOf[netip.AddrPort, struct{}] - packetConnMap *xsync.MapOf[uint16, struct{}] + connMap *xsync.Map[netip.AddrPort, struct{}] + packetConnMap *xsync.Map[uint16, struct{}] } func NewDetector() *Detector { @@ -35,8 +34,8 @@ func NewDetector() *Detector { return nil } return &Detector{ - connMap: xsync.NewMapOf[netip.AddrPort, struct{}](), - packetConnMap: xsync.NewMapOf[uint16, struct{}](), + connMap: xsync.NewMap[netip.AddrPort, struct{}](), + packetConnMap: xsync.NewMap[uint16, struct{}](), } } diff --git a/component/nat/table.go 
b/component/nat/table.go index 66241fb4..eeeba612 100644 --- a/component/nat/table.go +++ b/component/nat/table.go @@ -4,27 +4,26 @@ import ( "net" "sync" + "github.com/metacubex/mihomo/common/xsync" C "github.com/metacubex/mihomo/constant" - - "github.com/puzpuzpuz/xsync/v3" ) type Table struct { - mapping *xsync.MapOf[string, *entry] + mapping *xsync.Map[string, *entry] } type entry struct { PacketSender C.PacketSender - LocalUDPConnMap *xsync.MapOf[string, *net.UDPConn] - LocalLockMap *xsync.MapOf[string, *sync.Cond] + LocalUDPConnMap *xsync.Map[string, *net.UDPConn] + LocalLockMap *xsync.Map[string, *sync.Cond] } func (t *Table) GetOrCreate(key string, maker func() C.PacketSender) (C.PacketSender, bool) { - item, loaded := t.mapping.LoadOrCompute(key, func() *entry { + item, loaded := t.mapping.LoadOrStoreFn(key, func() *entry { return &entry{ PacketSender: maker(), - LocalUDPConnMap: xsync.NewMapOf[string, *net.UDPConn](), - LocalLockMap: xsync.NewMapOf[string, *sync.Cond](), + LocalUDPConnMap: xsync.NewMap[string, *net.UDPConn](), + LocalLockMap: xsync.NewMap[string, *sync.Cond](), } }) return item.PacketSender, loaded @@ -68,7 +67,7 @@ func (t *Table) GetOrCreateLockForLocalConn(lAddr, key string) (*sync.Cond, bool if !loaded { return nil, false } - item, loaded := entry.LocalLockMap.LoadOrCompute(key, makeLock) + item, loaded := entry.LocalLockMap.LoadOrStoreFn(key, makeLock) return item, loaded } @@ -99,6 +98,6 @@ func makeLock() *sync.Cond { // New return *Cache func New() *Table { return &Table{ - mapping: xsync.NewMapOf[string, *entry](), + mapping: xsync.NewMap[string, *entry](), } } diff --git a/go.mod b/go.mod index 4bd156cc..a2ba50ca 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,6 @@ require ( github.com/mroth/weightedrand/v2 v2.1.0 github.com/openacid/low v0.1.21 github.com/oschwald/maxminddb-golang v1.12.0 // lastest version compatible with golang1.20 - github.com/puzpuzpuz/xsync/v3 v3.5.1 github.com/sagernet/cors v1.2.1 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/samber/lo v1.51.0 diff --git a/go.sum b/go.sum index f089d652..060f4a12 100644 --- a/go.sum +++ b/go.sum @@ -167,8 +167,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= -github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= -github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/sagernet/cors v1.2.1 h1:Cv5Z8y9YSD6Gm+qSpNrL3LO4lD3eQVvbFYJSG7JCMHQ= diff --git a/transport/tuic/v4/client.go b/transport/tuic/v4/client.go index afa83d82..14cd9d2f 100644 --- a/transport/tuic/v4/client.go +++ b/transport/tuic/v4/client.go @@ -14,6 +14,7 @@ import ( atomic2 "github.com/metacubex/mihomo/common/atomic" N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/pool" + "github.com/metacubex/mihomo/common/xsync" tlsC "github.com/metacubex/mihomo/component/tls" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" @@ -21,7 +22,6 @@ import ( 
"github.com/metacubex/quic-go" "github.com/metacubex/randv2" - "github.com/puzpuzpuz/xsync/v3" ) type ClientOption struct { @@ -48,7 +48,7 @@ type clientImpl struct { openStreams atomic.Int64 closed atomic.Bool - udpInputMap *xsync.MapOf[uint32, net.Conn] + udpInputMap *xsync.Map[uint32, net.Conn] // only ready for PoolClient dialerRef C.Dialer @@ -422,7 +422,7 @@ func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client ClientOption: clientOption, udp: udp, dialerRef: dialerRef, - udpInputMap: xsync.NewMapOf[uint32, net.Conn](), + udpInputMap: xsync.NewMap[uint32, net.Conn](), } c := &Client{ci} runtime.SetFinalizer(c, closeClient) diff --git a/transport/tuic/v4/server.go b/transport/tuic/v4/server.go index 62ba5a58..6530e413 100644 --- a/transport/tuic/v4/server.go +++ b/transport/tuic/v4/server.go @@ -11,13 +11,13 @@ import ( "github.com/metacubex/mihomo/common/atomic" N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/pool" + "github.com/metacubex/mihomo/common/xsync" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/tuic/common" "github.com/gofrs/uuid/v5" "github.com/metacubex/quic-go" - "github.com/puzpuzpuz/xsync/v3" ) type ServerOption struct { @@ -34,7 +34,7 @@ func NewServerHandler(option *ServerOption, quicConn *quic.Conn, uuid uuid.UUID) quicConn: quicConn, uuid: uuid, authCh: make(chan struct{}), - udpInputMap: xsync.NewMapOf[uint32, *atomic.Bool](), + udpInputMap: xsync.NewMap[uint32, *atomic.Bool](), } } @@ -47,7 +47,7 @@ type serverHandler struct { authOk atomic.Bool authOnce sync.Once - udpInputMap *xsync.MapOf[uint32, *atomic.Bool] + udpInputMap *xsync.Map[uint32, *atomic.Bool] } func (s *serverHandler) AuthOk() bool { @@ -80,7 +80,7 @@ func (s *serverHandler) parsePacket(packet *Packet, udpRelayMode common.UdpRelay assocId = packet.ASSOC_ID - writeClosed, _ := s.udpInputMap.LoadOrCompute(assocId, func() *atomic.Bool { return &atomic.Bool{} }) + writeClosed, _ := s.udpInputMap.LoadOrStoreFn(assocId, func() *atomic.Bool { return &atomic.Bool{} }) if writeClosed.Load() { return nil } diff --git a/transport/tuic/v5/client.go b/transport/tuic/v5/client.go index ff6fbc3e..9053efdf 100644 --- a/transport/tuic/v5/client.go +++ b/transport/tuic/v5/client.go @@ -14,6 +14,7 @@ import ( atomic2 "github.com/metacubex/mihomo/common/atomic" N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/pool" + "github.com/metacubex/mihomo/common/xsync" tlsC "github.com/metacubex/mihomo/component/tls" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" @@ -21,7 +22,6 @@ import ( "github.com/metacubex/quic-go" "github.com/metacubex/randv2" - "github.com/puzpuzpuz/xsync/v3" ) type ClientOption struct { @@ -47,7 +47,7 @@ type clientImpl struct { openStreams atomic.Int64 closed atomic.Bool - udpInputMap *xsync.MapOf[uint16, net.Conn] + udpInputMap *xsync.Map[uint16, net.Conn] // only ready for PoolClient dialerRef C.Dialer @@ -406,7 +406,7 @@ func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client ClientOption: clientOption, udp: udp, dialerRef: dialerRef, - udpInputMap: xsync.NewMapOf[uint16, net.Conn](), + udpInputMap: xsync.NewMap[uint16, net.Conn](), } c := &Client{ci} runtime.SetFinalizer(c, closeClient) diff --git a/transport/tuic/v5/server.go b/transport/tuic/v5/server.go index 31bedf35..73cd179f 100644 --- a/transport/tuic/v5/server.go +++ b/transport/tuic/v5/server.go @@ -10,13 +10,13 @@ 
import ( "github.com/metacubex/mihomo/adapter/inbound" "github.com/metacubex/mihomo/common/atomic" N "github.com/metacubex/mihomo/common/net" + "github.com/metacubex/mihomo/common/xsync" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/tuic/common" "github.com/gofrs/uuid/v5" "github.com/metacubex/quic-go" - "github.com/puzpuzpuz/xsync/v3" ) type ServerOption struct { @@ -33,7 +33,7 @@ func NewServerHandler(option *ServerOption, quicConn *quic.Conn, uuid uuid.UUID) quicConn: quicConn, uuid: uuid, authCh: make(chan struct{}), - udpInputMap: xsync.NewMapOf[uint16, *serverUDPInput](), + udpInputMap: xsync.NewMap[uint16, *serverUDPInput](), } } @@ -47,7 +47,7 @@ type serverHandler struct { authUUID atomic.TypedValue[string] authOnce sync.Once - udpInputMap *xsync.MapOf[uint16, *serverUDPInput] + udpInputMap *xsync.Map[uint16, *serverUDPInput] } func (s *serverHandler) AuthOk() bool { @@ -96,7 +96,7 @@ func (s *serverHandler) parsePacket(packet *Packet, udpRelayMode common.UdpRelay assocId = packet.ASSOC_ID - input, _ := s.udpInputMap.LoadOrCompute(assocId, func() *serverUDPInput { return &serverUDPInput{} }) + input, _ := s.udpInputMap.LoadOrStoreFn(assocId, func() *serverUDPInput { return &serverUDPInput{} }) if input.writeClosed.Load() { return nil } diff --git a/tunnel/statistic/manager.go b/tunnel/statistic/manager.go index 3f2770c2..8761fda9 100644 --- a/tunnel/statistic/manager.go +++ b/tunnel/statistic/manager.go @@ -5,8 +5,8 @@ import ( "time" "github.com/metacubex/mihomo/common/atomic" + "github.com/metacubex/mihomo/common/xsync" - "github.com/puzpuzpuz/xsync/v3" "github.com/shirou/gopsutil/v4/process" ) @@ -14,7 +14,7 @@ var DefaultManager *Manager func init() { DefaultManager = &Manager{ - connections: xsync.NewMapOf[string, Tracker](), + connections: xsync.NewMap[string, Tracker](), uploadTemp: atomic.NewInt64(0), downloadTemp: atomic.NewInt64(0), uploadBlip: atomic.NewInt64(0), @@ -28,7 +28,7 @@ func init() { } type Manager struct { - connections *xsync.MapOf[string, Tracker] + connections *xsync.Map[string, Tracker] uploadTemp atomic.Int64 downloadTemp atomic.Int64 uploadBlip atomic.Int64