Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 73 additions & 22 deletions sampling/reservoir_items_sketch.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"math"
"math/rand"
"slices"
"strings"

"github.com/apache/datasketches-go/common"
Expand All @@ -44,11 +42,27 @@ const (

defaultResizeFactor = ResizeX8
minK = 2
maxItemsSeen = int64(0xFFFFFFFFFFFF)

// smallest sampling array allocated: 16
minLgArrItems = 4
)

func resizeFactorLg(rf ResizeFactor) (int, error) {
Copy link
Member

@proost proost Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can define ResizeFactor using iota, so the verbose switch/case syntax isn't needed.

switch rf {
case ResizeX1:
return 0, nil
case ResizeX2:
return 1, nil
case ResizeX4:
return 2, nil
case ResizeX8:
return 3, nil
default:
return 0, errors.New("unsupported resize factor")
}
}

// ReservoirItemsSketch provides a uniform random sample of items
// from a stream of unknown size using the reservoir sampling algorithm.
//
Expand Down Expand Up @@ -93,9 +107,14 @@ func NewReservoirItemsSketch[T any](
opt(options)
}

lgRf, err := resizeFactorLg(options.resizeFactor)
if err != nil {
return nil, err
}

ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(k))
initialLgSize := startingSubMultiple(
ceilingLgK, int(math.Log2(float64(options.resizeFactor))), minLgArrItems,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using ResizeFactor itself is OK, right?

ceilingLgK, lgRf, minLgArrItems,
)
return &ReservoirItemsSketch[T]{
k: k,
Expand All @@ -106,27 +125,44 @@ func NewReservoirItemsSketch[T any](
}

// Update adds an item to the sketch using reservoir sampling algorithm.
func (s *ReservoirItemsSketch[T]) Update(item T) {
func (s *ReservoirItemsSketch[T]) Update(item T) error {
if s.n == maxItemsSeen {
return fmt.Errorf("sketch has exceeded capacity for total items seen: %d", maxItemsSeen)
}

if s.n < int64(s.k) {
// Initial phase: store all items until reservoir is full
if s.n >= int64(cap(s.data)) {
s.growReservoir()
if err := s.growReservoir(); err != nil {
return err
}
}

s.data = append(s.data, item)
s.n++
} else {
// Steady state: replace with probability k/n
j := rand.Int63n(s.n + 1)
if j < int64(s.k) {
s.data[j] = item
s.n++
if rand.Float64()*float64(s.n) < float64(s.k) {
s.data[rand.Intn(s.k)] = item
}
}
s.n++
return nil
}

func (s *ReservoirItemsSketch[T]) growReservoir() {
adjustedSize := adjustedSamplingAllocationSize(s.k, cap(s.data)<<int(s.rf))
s.data = slices.Grow(s.data, adjustedSize)
func (s *ReservoirItemsSketch[T]) growReservoir() error {
lgRf, err := resizeFactorLg(s.rf)
if err != nil {
return err
}
targetCap := adjustedSamplingAllocationSize(s.k, cap(s.data)<<lgRf)
if targetCap <= cap(s.data) {
return nil
}
newData := make([]T, len(s.data), targetCap)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L162–L164 are too expensive. That is why I use Grow.

copy(newData, s.data)
s.data = newData
return nil
}

// K returns the maximum reservoir capacity.
Expand Down Expand Up @@ -158,9 +194,10 @@ func (s *ReservoirItemsSketch[T]) IsEmpty() bool {

// Reset clears the sketch while preserving capacity k.
func (s *ReservoirItemsSketch[T]) Reset() {
lgRf, _ := resizeFactorLg(s.rf)
ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(s.k))
initialLgSize := startingSubMultiple(
ceilingLgK, int(math.Log2(float64(s.rf))), minLgArrItems,
ceilingLgK, lgRf, minLgArrItems,
)

s.n = 0
Expand Down Expand Up @@ -220,18 +257,18 @@ func (s *ReservoirItemsSketch[T]) EstimateSubsetSum(predicate func(T) bool) (Sam

lowerBoundTrueFraction, err := pseudoHypergeometricLowerBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate)
if err != nil {
return SampleSubsetSummary{}, nil
return SampleSubsetSummary{}, err
}
upperBoundTrueFraction, err := pseudoHypergeometricUpperBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate)
if err != nil {
return SampleSubsetSummary{}, nil
return SampleSubsetSummary{}, err
}
estimatedTrueFraction := (1.0 * float64(trueCount)) / float64(numSamples)
return SampleSubsetSummary{
LowerBound: lowerBoundTrueFraction,
Estimate: estimatedTrueFraction,
UpperBound: upperBoundTrueFraction,
TotalSketchWeight: float64(numSamples),
LowerBound: float64(s.n) * lowerBoundTrueFraction,
Estimate: float64(s.n) * estimatedTrueFraction,
UpperBound: float64(s.n) * upperBoundTrueFraction,
TotalSketchWeight: float64(s.n),
}, nil
}

Expand All @@ -249,12 +286,16 @@ func (s *ReservoirItemsSketch[T]) DownsampledCopy(newK int) (*ReservoirItemsSket

samples := s.Samples()
for _, item := range samples {
result.Update(item)
if err := result.Update(item); err != nil {
return nil, err
}
}

// Adjust N to preserve correct implicit weights
if result.n < s.n {
result.forceIncrementItemsSeen(s.n - result.n)
if err := result.forceIncrementItemsSeen(s.n - result.n); err != nil {
return nil, err
}
}

return result, nil
Expand All @@ -271,8 +312,15 @@ func (s *ReservoirItemsSketch[T]) insertValueAtPosition(item T, pos int) {
}

// forceIncrementItemsSeen adds delta to the items seen count.
func (s *ReservoirItemsSketch[T]) forceIncrementItemsSeen(delta int64) {
func (s *ReservoirItemsSketch[T]) forceIncrementItemsSeen(delta int64) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you add delta to s.n first, the state has already changed, so guarding against going above maxItemsSeen is correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right; I was wrong. Changing state first is correct, since the sketch's state has already gone wrong.

s.n += delta
if s.n > maxItemsSeen {
return fmt.Errorf(
"sketch has exceeded capacity for total items seen. limit: %d, found: %d",
maxItemsSeen, s.n,
)
}
return nil
}

// Serialization constants
Expand Down Expand Up @@ -417,6 +465,9 @@ func NewReservoirItemsSketchFromSlice[T any](data []byte, serde ItemsSerDe[T]) (
}

n := int64(binary.LittleEndian.Uint64(data[8:]))
if n > maxItemsSeen {
return nil, fmt.Errorf("items seen exceeds limit: %d", maxItemsSeen)
}
numSamples := int(min(n, int64(k)))

itemsData := data[preambleBytes:]
Expand Down
107 changes: 90 additions & 17 deletions sampling/reservoir_items_sketch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package sampling

import (
"encoding/binary"
"math"
"math/rand"
"testing"

Expand All @@ -42,8 +41,8 @@ func TestReservoirItemsSketchWithStrings(t *testing.T) {
assert.NoError(t, err)

sketch.Update("apple")
sketch.Update("banana")
sketch.Update("cherry")
_ = sketch.Update("banana")
_ = sketch.Update("cherry")

assert.Equal(t, int64(3), sketch.N())
assert.Equal(t, 3, sketch.NumSamples())
Expand All @@ -63,9 +62,9 @@ func TestReservoirItemsSketchWithStruct(t *testing.T) {
sketch, err := NewReservoirItemsSketch[Event](5)
assert.NoError(t, err)

sketch.Update(Event{1, "login"})
sketch.Update(Event{2, "logout"})
sketch.Update(Event{3, "click"})
_ = sketch.Update(Event{1, "login"})
_ = sketch.Update(Event{2, "logout"})
_ = sketch.Update(Event{3, "click"})

assert.Equal(t, int64(3), sketch.N())
samples := sketch.Samples()
Expand All @@ -78,6 +77,9 @@ func TestReservoirItemsSketchInvalidK(t *testing.T) {

_, err = NewReservoirItemsSketch[int64](1)
assert.ErrorContains(t, err, "k must be at least 2")

_, err = NewReservoirItemsSketch[int64](16, WithReservoirItemsSketchResizeFactor(ResizeFactor(3)))
assert.ErrorContains(t, err, "unsupported resize factor")
}

func TestReservoirItemsSketch_Update(t *testing.T) {
Expand Down Expand Up @@ -150,8 +152,9 @@ func TestReservoirItemsSketchReset(t *testing.T) {
assert.NoError(t, err)

ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(k))
lgRf, _ := resizeFactorLg(defaultResizeFactor)
initialLgSize := startingSubMultiple(
ceilingLgK, int(math.Log2(float64(defaultResizeFactor))), minLgArrItems,
ceilingLgK, lgRf, minLgArrItems,
)
expectedInitialCap := adjustedSamplingAllocationSize(k, 1<<initialLgSize)

Expand Down Expand Up @@ -199,6 +202,45 @@ func TestReservoirItemsSketchResizeFactorSerialization(t *testing.T) {
assert.Equal(t, ResizeX2, restored.rf)
}

// TestReservoirItemsSketchResizeFactorGrowth checks, for each supported resize
// factor, that a freshly built sketch starts at the expected initial capacity
// and that the first growth reaches at least the capacity implied by that
// factor.
func TestReservoirItemsSketchResizeFactorGrowth(t *testing.T) {
	tests := []struct {
		name     string
		rf       ResizeFactor
		wantLgRf int
	}{
		{name: "X1", rf: ResizeX1, wantLgRf: 0},
		{name: "X2", rf: ResizeX2, wantLgRf: 1},
		{name: "X4", rf: ResizeX4, wantLgRf: 2},
		{name: "X8", rf: ResizeX8, wantLgRf: 3},
	}

	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			lgK := 8 // k=256 gives room for growth before clamping
			k := 1 << lgK
			minLg := minLgArrItems

			sketch, err := NewReservoirItemsSketch[int64](k, WithReservoirItemsSketchResizeFactor(tc.rf))
			if !assert.NoError(t, err) {
				// Without a sketch the remaining checks would dereference nil.
				return
			}

			initialLg := startingSubMultiple(lgK, tc.wantLgRf, minLg)
			initialCap := adjustedSamplingAllocationSize(k, 1<<initialLg)

			assert.Equal(t, initialCap, cap(sketch.data))

			// Fill to initial capacity then trigger one growth. Update reports
			// failures through its error return, so every call is checked
			// rather than discarded.
			for i := 0; i < initialCap; i++ {
				assert.NoError(t, sketch.Update(int64(i)))
			}
			assert.NoError(t, sketch.Update(int64(initialCap)))

			expectedTarget := adjustedSamplingAllocationSize(k, initialCap<<tc.wantLgRf)
			assert.GreaterOrEqual(t, cap(sketch.data), expectedTarget)
		})
	}
}

func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) {
t.Run("EmptySketch", func(t *testing.T) {
sketch, err := NewReservoirItemsSketch[int64](10)
Expand Down Expand Up @@ -239,8 +281,8 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 0.0, summary.Estimate)
assert.Equal(t, 0.0, summary.LowerBound)
assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0)
assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight)
assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= float64(sketch.N()))
assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight)
})

t.Run("EstimationModePredicateAlwaysMatches", func(t *testing.T) {
Expand All @@ -253,10 +295,10 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) {

summary, err := sketch.EstimateSubsetSum(func(int64) bool { return true })
assert.NoError(t, err)
assert.Equal(t, 1.0, summary.Estimate)
assert.Equal(t, 1.0, summary.UpperBound)
assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0)
assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight)
assert.Equal(t, float64(sketch.N()), summary.Estimate)
assert.Equal(t, float64(sketch.N()), summary.UpperBound)
assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= float64(sketch.N()))
assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight)
})

t.Run("EstimationModePredicatePartiallyMatches", func(t *testing.T) {
Expand All @@ -274,16 +316,16 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) {
trueCount++
}
}
expectedEstimate := float64(trueCount) / float64(len(samples))
expectedEstimate := float64(sketch.N()) * (float64(trueCount) / float64(len(samples)))

summary, err := sketch.EstimateSubsetSum(func(v int64) bool { return v%2 == 0 })
assert.NoError(t, err)
assert.InDelta(t, expectedEstimate, summary.Estimate, 0.0)
assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0)
assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0)
assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= float64(sketch.N()))
assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= float64(sketch.N()))
assert.True(t, summary.LowerBound <= summary.Estimate)
assert.True(t, summary.Estimate <= summary.UpperBound)
assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight)
assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight)
})
}

Expand All @@ -301,3 +343,34 @@ func TestReservoirItemsSketchLegacySerVerEmpty(t *testing.T) {
assert.Equal(t, 1024, sketch.K())
assert.Equal(t, ResizeX8, sketch.rf)
}

func TestReservoirItemsSketchUpdateReturnsErrorAtMaxItemsSeen(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this case can be included in an already defined test function as a nested case.

sketch, err := NewReservoirItemsSketch[int64](8)
assert.NoError(t, err)
sketch.n = maxItemsSeen

err = sketch.Update(1)
assert.ErrorContains(t, err, "sketch has exceeded capacity")
}

// TestReservoirItemsSketchForceIncrementItemsSeenReturnsErrorOnOverflow checks
// that forceIncrementItemsSeen reports an error when the requested delta
// pushes the items-seen counter past maxItemsSeen.
func TestReservoirItemsSketchForceIncrementItemsSeenReturnsErrorOnOverflow(t *testing.T) {
	rs, newErr := NewReservoirItemsSketch[int64](8)
	assert.NoError(t, newErr)

	// Start one below the limit so that a delta of 2 overshoots it.
	rs.n = maxItemsSeen - 1

	incErr := rs.forceIncrementItemsSeen(2)
	assert.ErrorContains(t, incErr, "sketch has exceeded capacity")
}

// TestReservoirItemsSketchFromSliceRejectsNTooLarge builds a 16-byte image
// whose 64-bit field at offset 8 is maxItemsSeen+1 and expects
// NewReservoirItemsSketchFromSlice to reject it.
func TestReservoirItemsSketchFromSliceRejectsNTooLarge(t *testing.T) {
	const (
		headerLen = 16
		off32     = 4 // offset of the 32-bit field, written as 8 here
		off64     = 8 // offset of the 64-bit items-seen field
	)

	raw := make([]byte, headerLen)
	raw[0] = 0xC0 | preambleIntsNonEmpty
	raw[1] = serVer
	raw[2] = byte(internal.FamilyEnum.ReservoirItems.Id)
	raw[3] = 0
	binary.LittleEndian.PutUint32(raw[off32:], uint32(8))
	// One past the allowed maximum for the items-seen count.
	binary.LittleEndian.PutUint64(raw[off64:], uint64(maxItemsSeen+1))

	_, decodeErr := NewReservoirItemsSketchFromSlice[int64](raw, Int64SerDe{})
	assert.ErrorContains(t, decodeErr, "items seen exceeds limit")
}
Loading