diff --git a/query/math.go b/query/math.go index a020066c40d..8b3f7560f32 100644 --- a/query/math.go +++ b/query/math.go @@ -94,9 +94,37 @@ func processBinary(mNode *mathTree) error { } if mpl.Len() != 0 || mpr.Len() != 0 { - var wg sync.WaitGroup returnMap := types.NewShardedMap() + // For small inputs (the common case in DQL traversal where math operates on a + // handful of uids), launching 30 goroutines + a WaitGroup costs more than the + // work itself. Process inline below a threshold. + const parallelThreshold = 512 + if mpl.Len()+mpr.Len() < parallelThreshold { + for i := range types.NumShards { + // PeekShard reads only — does not allocate an empty shard map. + mlps := mpl.PeekShard(i) + mprs := mpr.PeekShard(i) + if len(mlps) == 0 && len(mprs) == 0 { + continue + } + // destMapi is written to inside f, so we need GetShardOrNil. + destMapi := returnMap.GetShardOrNil(i) + for k := range mlps { + f(k, &mlps, &mprs, &destMapi) + } + for k := range mprs { + if _, ok := mlps[k]; ok { + continue + } + f(k, &mlps, &mprs, &destMapi) + } + } + mNode.Val = returnMap + return nil + } + + var wg sync.WaitGroup for i := range types.NumShards { wg.Add(1) mlps := mpl.GetShardOrNil(i) diff --git a/types/sharded_map.go b/types/sharded_map.go index 0317ff2c551..8ce19066679 100644 --- a/types/sharded_map.go +++ b/types/sharded_map.go @@ -11,26 +11,33 @@ import ( "sync" ) +const NumShards = 30 + type ShardedMap struct { - shards []map[uint64]Val + // shards is a fixed-size array, not a slice, so allocating a ShardedMap + // performs exactly one heap allocation. Individual shard maps are created + // lazily on first write, so workloads that touch few shards (or none) + // pay near-zero allocation cost. + shards [NumShards]map[uint64]Val } -const NumShards = 30 - func NewShardedMap() *ShardedMap { - shards := make([]map[uint64]Val, NumShards) - for i := range shards { - shards[i] = make(map[uint64]Val) - } - return &ShardedMap{shards: shards} + return &ShardedMap{} } func (s *ShardedMap) Merge(other *ShardedMap, ag func(a, b Val) Val) { // TODO: ideally othermap should be the one which is smaller in size. var wg sync.WaitGroup for i := range s.shards { + // Skip shards that are empty in both maps — no work, no goroutine launch. + if len(other.shards[i]) == 0 { + continue + } wg.Add(1) go func(i int) { + if s.shards[i] == nil { + s.shards[i] = make(map[uint64]Val, len(other.shards[i])) + } for k, v := range other.shards[i] { if _, ok := s.shards[i][k]; ok { s.shards[i][k] = ag(s.shards[i][k], v) @@ -49,13 +56,35 @@ func (s *ShardedMap) IsEmpty() bool { if s == nil { return true } - return len(s.shards) == 0 + for i := range s.shards { + if len(s.shards[i]) > 0 { + return false + } + } + return true } +// GetShardOrNil returns the underlying map for the given shard, creating it +// if necessary. Callers may write through the returned reference; those writes +// persist in the ShardedMap. (The original implementation eagerly created all +// shards, so callers depended on the returned map being a live reference.) func (s *ShardedMap) GetShardOrNil(key int) map[uint64]Val { if s == nil { return make(map[uint64]Val) } + if s.shards[key] == nil { + s.shards[key] = make(map[uint64]Val) + } + return s.shards[key] +} + +// PeekShard returns the underlying shard map without allocating one if it does +// not yet exist. Callers MUST NOT write to the returned map — use GetShardOrNil +// for that. This is the right call for iterate-only / range-only access. +func (s *ShardedMap) PeekShard(key int) map[uint64]Val { + if s == nil { + return nil + } return s.shards[key] } @@ -65,24 +94,30 @@ func (s *ShardedMap) init() { } } -func (s *ShardedMap) getShard(key uint64) map[uint64]Val { - return s.shards[key%NumShards] +func (s *ShardedMap) getShardIdx(key uint64) int { + return int(key % NumShards) } func (s *ShardedMap) Set(key uint64, value Val) { if s == nil { s.init() } - shard := s.getShard(key) - shard[key] = value + idx := s.getShardIdx(key) + if s.shards[idx] == nil { + s.shards[idx] = make(map[uint64]Val) + } + s.shards[idx][key] = value } func (s *ShardedMap) Get(key uint64) (Val, bool) { if s == nil { return Val{}, false } - shard := s.getShard(key) - val, ok := shard[key] + idx := s.getShardIdx(key) + if s.shards[idx] == nil { + return Val{}, false + } + val, ok := s.shards[idx][key] return val, ok } @@ -91,8 +126,8 @@ func (s *ShardedMap) Len() int { return 0 } var count int - for _, shard := range s.shards { - count += len(shard) + for i := range s.shards { + count += len(s.shards[i]) } return count } @@ -101,8 +136,11 @@ func (s *ShardedMap) Iterate(f func(uint64, Val) error) error { if s == nil { return nil } - for _, shard := range s.shards { - for k, v := range shard { + for i := range s.shards { + if s.shards[i] == nil { + continue + } + for k, v := range s.shards[i] { if err := f(k, v); err != nil { return err } diff --git a/types/sharded_map_test.go b/types/sharded_map_test.go new file mode 100644 index 00000000000..5c552710852 --- /dev/null +++ b/types/sharded_map_test.go @@ -0,0 +1,118 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package types + +import ( + "testing" +) + +func TestShardedMapBasic(t *testing.T) { + m := NewShardedMap() + v := Val{Tid: IntID, Value: int64(42)} + m.Set(7, v) + + got, ok := m.Get(7) + if !ok || got.Value != int64(42) { + t.Fatalf("Get(7) = %v, %v; want 42, true", got, ok) + } + if _, ok := m.Get(8); ok { + t.Fatalf("Get(8) = ok; want missing") + } + if m.Len() != 1 { + t.Fatalf("Len = %d; want 1", m.Len()) + } +} + +func TestShardedMapEmpty(t *testing.T) { + var m *ShardedMap + if !m.IsEmpty() { + t.Fatalf("nil IsEmpty should be true") + } + if _, ok := m.Get(1); ok { + t.Fatalf("nil Get returned ok=true") + } + if m.Len() != 0 { + t.Fatalf("nil Len = %d; want 0", m.Len()) + } + + m = NewShardedMap() + if !m.IsEmpty() { + t.Fatalf("Fresh ShardedMap IsEmpty should be true") + } +} + +func TestShardedMapIterate(t *testing.T) { + m := NewShardedMap() + for i := uint64(0); i < 100; i++ { + m.Set(i, Val{Tid: IntID, Value: int64(i)}) + } + count := 0 + err := m.Iterate(func(k uint64, v Val) error { + count++ + return nil + }) + if err != nil { + t.Fatal(err) + } + if count != 100 { + t.Fatalf("iter count = %d; want 100", count) + } +} + +func TestShardedMapMerge(t *testing.T) { + a := NewShardedMap() + b := NewShardedMap() + for i := uint64(0); i < 10; i++ { + a.Set(i, Val{Tid: IntID, Value: int64(i)}) + b.Set(i+5, Val{Tid: IntID, Value: int64(i + 100)}) + } + a.Merge(b, func(x, y Val) Val { return y }) + if a.Len() != 15 { + t.Fatalf("merged Len = %d; want 15", a.Len()) + } + got, _ := a.Get(7) + if got.Value != int64(102) { + t.Fatalf("Merged Get(7).Value = %v; want 102", got) + } +} + +// BenchmarkShardedMapNew measures the cost of constructing a fresh ShardedMap +// with no inserts. This is the case where lazy shard init pays off the most: +// no shard maps should be allocated. +func BenchmarkShardedMapNew(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m := NewShardedMap() + _ = m + } +} + +// BenchmarkShardedMapSetGet measures a typical small-data pattern: a few +// inserts, a few reads, a Len/IsEmpty check. +func BenchmarkShardedMapSetGet(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m := NewShardedMap() + for j := uint64(0); j < 8; j++ { + m.Set(j, Val{Tid: IntID, Value: int64(j)}) + } + for j := uint64(0); j < 8; j++ { + _, _ = m.Get(j) + } + } +} + +// BenchmarkShardedMapFull writes to all 30 shards — measures the worst case +// where lazy init provides no savings. +func BenchmarkShardedMapFull(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m := NewShardedMap() + for j := uint64(0); j < 30; j++ { + m.Set(j, Val{Tid: IntID, Value: int64(j)}) + } + } +}