-
Notifications
You must be signed in to change notification settings - Fork 11
fix(sampling): align resize-factor behavior and replace reservoir panics with errors #120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,9 +21,7 @@ import ( | |
| "encoding/binary" | ||
| "errors" | ||
| "fmt" | ||
| "math" | ||
| "math/rand" | ||
| "slices" | ||
| "strings" | ||
|
|
||
| "github.com/apache/datasketches-go/common" | ||
|
|
@@ -44,11 +42,27 @@ const ( | |
|
|
||
| defaultResizeFactor = ResizeX8 | ||
| minK = 2 | ||
| maxItemsSeen = int64(0xFFFFFFFFFFFF) | ||
|
|
||
| // smallest sampling array allocated: 16 | ||
| minLgArrItems = 4 | ||
| ) | ||
|
|
||
| func resizeFactorLg(rf ResizeFactor) (int, error) { | ||
| switch rf { | ||
| case ResizeX1: | ||
| return 0, nil | ||
| case ResizeX2: | ||
| return 1, nil | ||
| case ResizeX4: | ||
| return 2, nil | ||
| case ResizeX8: | ||
| return 3, nil | ||
| default: | ||
| return 0, errors.New("unsupported resize factor") | ||
| } | ||
| } | ||
|
|
||
| // ReservoirItemsSketch provides a uniform random sample of items | ||
| // from a stream of unknown size using the reservoir sampling algorithm. | ||
| // | ||
|
|
@@ -93,9 +107,14 @@ func NewReservoirItemsSketch[T any]( | |
| opt(options) | ||
| } | ||
|
|
||
| lgRf, err := resizeFactorLg(options.resizeFactor) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(k)) | ||
| initialLgSize := startingSubMultiple( | ||
| ceilingLgK, int(math.Log2(float64(options.resizeFactor))), minLgArrItems, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using ResizeFactor itself is OK, right? |
||
| ceilingLgK, lgRf, minLgArrItems, | ||
| ) | ||
| return &ReservoirItemsSketch[T]{ | ||
| k: k, | ||
|
|
@@ -106,27 +125,44 @@ func NewReservoirItemsSketch[T any]( | |
| } | ||
|
|
||
| // Update adds an item to the sketch using reservoir sampling algorithm. | ||
| func (s *ReservoirItemsSketch[T]) Update(item T) { | ||
| func (s *ReservoirItemsSketch[T]) Update(item T) error { | ||
| if s.n == maxItemsSeen { | ||
| return fmt.Errorf("sketch has exceeded capacity for total items seen: %d", maxItemsSeen) | ||
| } | ||
|
|
||
| if s.n < int64(s.k) { | ||
| // Initial phase: store all items until reservoir is full | ||
| if s.n >= int64(cap(s.data)) { | ||
| s.growReservoir() | ||
| if err := s.growReservoir(); err != nil { | ||
| return err | ||
| } | ||
| } | ||
|
|
||
| s.data = append(s.data, item) | ||
| s.n++ | ||
| } else { | ||
| // Steady state: replace with probability k/n | ||
| j := rand.Int63n(s.n + 1) | ||
| if j < int64(s.k) { | ||
| s.data[j] = item | ||
| s.n++ | ||
| if rand.Float64()*float64(s.n) < float64(s.k) { | ||
| s.data[rand.Intn(s.k)] = item | ||
| } | ||
| } | ||
| s.n++ | ||
| return nil | ||
| } | ||
|
|
||
| func (s *ReservoirItemsSketch[T]) growReservoir() { | ||
| adjustedSize := adjustedSamplingAllocationSize(s.k, cap(s.data)<<int(s.rf)) | ||
| s.data = slices.Grow(s.data, adjustedSize) | ||
| func (s *ReservoirItemsSketch[T]) growReservoir() error { | ||
| lgRf, err := resizeFactorLg(s.rf) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| targetCap := adjustedSamplingAllocationSize(s.k, cap(s.data)<<lgRf) | ||
| if targetCap <= cap(s.data) { | ||
| return nil | ||
| } | ||
| newData := make([]T, len(s.data), targetCap) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. L162 ~ L164 is too expensive. That is why I use `slices.Grow`. |
||
| copy(newData, s.data) | ||
| s.data = newData | ||
| return nil | ||
| } | ||
|
|
||
| // K returns the maximum reservoir capacity. | ||
|
|
@@ -158,9 +194,10 @@ func (s *ReservoirItemsSketch[T]) IsEmpty() bool { | |
|
|
||
| // Reset clears the sketch while preserving capacity k. | ||
| func (s *ReservoirItemsSketch[T]) Reset() { | ||
| lgRf, _ := resizeFactorLg(s.rf) | ||
| ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(s.k)) | ||
| initialLgSize := startingSubMultiple( | ||
| ceilingLgK, int(math.Log2(float64(s.rf))), minLgArrItems, | ||
| ceilingLgK, lgRf, minLgArrItems, | ||
| ) | ||
|
|
||
| s.n = 0 | ||
|
|
@@ -220,18 +257,18 @@ func (s *ReservoirItemsSketch[T]) EstimateSubsetSum(predicate func(T) bool) (Sam | |
|
|
||
| lowerBoundTrueFraction, err := pseudoHypergeometricLowerBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate) | ||
| if err != nil { | ||
| return SampleSubsetSummary{}, nil | ||
| return SampleSubsetSummary{}, err | ||
| } | ||
| upperBoundTrueFraction, err := pseudoHypergeometricUpperBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate) | ||
| if err != nil { | ||
| return SampleSubsetSummary{}, nil | ||
| return SampleSubsetSummary{}, err | ||
| } | ||
| estimatedTrueFraction := (1.0 * float64(trueCount)) / float64(numSamples) | ||
| return SampleSubsetSummary{ | ||
| LowerBound: lowerBoundTrueFraction, | ||
| Estimate: estimatedTrueFraction, | ||
| UpperBound: upperBoundTrueFraction, | ||
| TotalSketchWeight: float64(numSamples), | ||
| LowerBound: float64(s.n) * lowerBoundTrueFraction, | ||
| Estimate: float64(s.n) * estimatedTrueFraction, | ||
| UpperBound: float64(s.n) * upperBoundTrueFraction, | ||
| TotalSketchWeight: float64(s.n), | ||
| }, nil | ||
| } | ||
|
|
||
|
|
@@ -249,12 +286,16 @@ func (s *ReservoirItemsSketch[T]) DownsampledCopy(newK int) (*ReservoirItemsSket | |
|
|
||
| samples := s.Samples() | ||
| for _, item := range samples { | ||
| result.Update(item) | ||
| if err := result.Update(item); err != nil { | ||
| return nil, err | ||
| } | ||
| } | ||
|
|
||
| // Adjust N to preserve correct implicit weights | ||
| if result.n < s.n { | ||
| result.forceIncrementItemsSeen(s.n - result.n) | ||
| if err := result.forceIncrementItemsSeen(s.n - result.n); err != nil { | ||
| return nil, err | ||
| } | ||
| } | ||
|
|
||
| return result, nil | ||
|
|
@@ -271,8 +312,15 @@ func (s *ReservoirItemsSketch[T]) insertValueAtPosition(item T, pos int) { | |
| } | ||
|
|
||
| // forceIncrementItemsSeen adds delta to the items seen count. | ||
| func (s *ReservoirItemsSketch[T]) forceIncrementItemsSeen(delta int64) { | ||
| func (s *ReservoirItemsSketch[T]) forceIncrementItemsSeen(delta int64) error { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you add
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You are right. I was wrong. Changing the state first is correct. The sketch's state has already gone wrong. |
||
| s.n += delta | ||
| if s.n > maxItemsSeen { | ||
| return fmt.Errorf( | ||
| "sketch has exceeded capacity for total items seen. limit: %d, found: %d", | ||
| maxItemsSeen, s.n, | ||
| ) | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| // Serialization constants | ||
|
|
@@ -417,6 +465,9 @@ func NewReservoirItemsSketchFromSlice[T any](data []byte, serde ItemsSerDe[T]) ( | |
| } | ||
|
|
||
| n := int64(binary.LittleEndian.Uint64(data[8:])) | ||
| if n > maxItemsSeen { | ||
| return nil, fmt.Errorf("items seen exceeds limit: %d", maxItemsSeen) | ||
| } | ||
| numSamples := int(min(n, int64(k))) | ||
|
|
||
| itemsData := data[preambleBytes:] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,6 @@ package sampling | |
|
|
||
| import ( | ||
| "encoding/binary" | ||
| "math" | ||
| "math/rand" | ||
| "testing" | ||
|
|
||
|
|
@@ -42,8 +41,8 @@ func TestReservoirItemsSketchWithStrings(t *testing.T) { | |
| assert.NoError(t, err) | ||
|
|
||
| sketch.Update("apple") | ||
| sketch.Update("banana") | ||
| sketch.Update("cherry") | ||
| _ = sketch.Update("banana") | ||
| _ = sketch.Update("cherry") | ||
|
|
||
| assert.Equal(t, int64(3), sketch.N()) | ||
| assert.Equal(t, 3, sketch.NumSamples()) | ||
|
|
@@ -63,9 +62,9 @@ func TestReservoirItemsSketchWithStruct(t *testing.T) { | |
| sketch, err := NewReservoirItemsSketch[Event](5) | ||
| assert.NoError(t, err) | ||
|
|
||
| sketch.Update(Event{1, "login"}) | ||
| sketch.Update(Event{2, "logout"}) | ||
| sketch.Update(Event{3, "click"}) | ||
| _ = sketch.Update(Event{1, "login"}) | ||
| _ = sketch.Update(Event{2, "logout"}) | ||
| _ = sketch.Update(Event{3, "click"}) | ||
|
|
||
| assert.Equal(t, int64(3), sketch.N()) | ||
| samples := sketch.Samples() | ||
|
|
@@ -78,6 +77,9 @@ func TestReservoirItemsSketchInvalidK(t *testing.T) { | |
|
|
||
| _, err = NewReservoirItemsSketch[int64](1) | ||
| assert.ErrorContains(t, err, "k must be at least 2") | ||
|
|
||
| _, err = NewReservoirItemsSketch[int64](16, WithReservoirItemsSketchResizeFactor(ResizeFactor(3))) | ||
| assert.ErrorContains(t, err, "unsupported resize factor") | ||
| } | ||
|
|
||
| func TestReservoirItemsSketch_Update(t *testing.T) { | ||
|
|
@@ -150,8 +152,9 @@ func TestReservoirItemsSketchReset(t *testing.T) { | |
| assert.NoError(t, err) | ||
|
|
||
| ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(k)) | ||
| lgRf, _ := resizeFactorLg(defaultResizeFactor) | ||
| initialLgSize := startingSubMultiple( | ||
| ceilingLgK, int(math.Log2(float64(defaultResizeFactor))), minLgArrItems, | ||
| ceilingLgK, lgRf, minLgArrItems, | ||
| ) | ||
| expectedInitialCap := adjustedSamplingAllocationSize(k, 1<<initialLgSize) | ||
|
|
||
|
|
@@ -199,6 +202,45 @@ func TestReservoirItemsSketchResizeFactorSerialization(t *testing.T) { | |
| assert.Equal(t, ResizeX2, restored.rf) | ||
| } | ||
|
|
||
| func TestReservoirItemsSketchResizeFactorGrowth(t *testing.T) { | ||
| tests := []struct { | ||
| name string | ||
| rf ResizeFactor | ||
| wantLgRf int | ||
| }{ | ||
| {name: "X1", rf: ResizeX1, wantLgRf: 0}, | ||
| {name: "X2", rf: ResizeX2, wantLgRf: 1}, | ||
| {name: "X4", rf: ResizeX4, wantLgRf: 2}, | ||
| {name: "X8", rf: ResizeX8, wantLgRf: 3}, | ||
| } | ||
|
|
||
| for _, tc := range tests { | ||
| tc := tc | ||
| t.Run(tc.name, func(t *testing.T) { | ||
| lgK := 8 // k=256 gives room for growth before clamping | ||
| k := 1 << lgK | ||
| minLg := minLgArrItems | ||
|
|
||
| sketch, err := NewReservoirItemsSketch[int64](k, WithReservoirItemsSketchResizeFactor(tc.rf)) | ||
| assert.NoError(t, err) | ||
|
|
||
| initialLg := startingSubMultiple(lgK, tc.wantLgRf, minLg) | ||
| initialCap := adjustedSamplingAllocationSize(k, 1<<initialLg) | ||
|
|
||
| assert.Equal(t, initialCap, cap(sketch.data)) | ||
|
|
||
| // Fill to initial capacity then trigger one growth. | ||
| for i := 0; i < initialCap; i++ { | ||
| sketch.Update(int64(i)) | ||
| } | ||
| sketch.Update(int64(initialCap)) | ||
|
|
||
| expectedTarget := adjustedSamplingAllocationSize(k, initialCap<<tc.wantLgRf) | ||
| assert.GreaterOrEqual(t, cap(sketch.data), expectedTarget) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) { | ||
| t.Run("EmptySketch", func(t *testing.T) { | ||
| sketch, err := NewReservoirItemsSketch[int64](10) | ||
|
|
@@ -239,8 +281,8 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) { | |
| assert.NoError(t, err) | ||
| assert.Equal(t, 0.0, summary.Estimate) | ||
| assert.Equal(t, 0.0, summary.LowerBound) | ||
| assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0) | ||
| assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) | ||
| assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= float64(sketch.N())) | ||
| assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight) | ||
| }) | ||
|
|
||
| t.Run("EstimationModePredicateAlwaysMatches", func(t *testing.T) { | ||
|
|
@@ -253,10 +295,10 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) { | |
|
|
||
| summary, err := sketch.EstimateSubsetSum(func(int64) bool { return true }) | ||
| assert.NoError(t, err) | ||
| assert.Equal(t, 1.0, summary.Estimate) | ||
| assert.Equal(t, 1.0, summary.UpperBound) | ||
| assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0) | ||
| assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) | ||
| assert.Equal(t, float64(sketch.N()), summary.Estimate) | ||
| assert.Equal(t, float64(sketch.N()), summary.UpperBound) | ||
| assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= float64(sketch.N())) | ||
| assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight) | ||
| }) | ||
|
|
||
| t.Run("EstimationModePredicatePartiallyMatches", func(t *testing.T) { | ||
|
|
@@ -274,16 +316,16 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) { | |
| trueCount++ | ||
| } | ||
| } | ||
| expectedEstimate := float64(trueCount) / float64(len(samples)) | ||
| expectedEstimate := float64(sketch.N()) * (float64(trueCount) / float64(len(samples))) | ||
|
|
||
| summary, err := sketch.EstimateSubsetSum(func(v int64) bool { return v%2 == 0 }) | ||
| assert.NoError(t, err) | ||
| assert.InDelta(t, expectedEstimate, summary.Estimate, 0.0) | ||
| assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0) | ||
| assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0) | ||
| assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= float64(sketch.N())) | ||
| assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= float64(sketch.N())) | ||
| assert.True(t, summary.LowerBound <= summary.Estimate) | ||
| assert.True(t, summary.Estimate <= summary.UpperBound) | ||
| assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) | ||
| assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight) | ||
| }) | ||
| } | ||
|
|
||
|
|
@@ -301,3 +343,34 @@ func TestReservoirItemsSketchLegacySerVerEmpty(t *testing.T) { | |
| assert.Equal(t, 1024, sketch.K()) | ||
| assert.Equal(t, ResizeX8, sketch.rf) | ||
| } | ||
|
|
||
| func TestReservoirItemsSketchUpdateReturnsErrorAtMaxItemsSeen(t *testing.T) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this case is included in the already defined test function as a nested case. |
||
| sketch, err := NewReservoirItemsSketch[int64](8) | ||
| assert.NoError(t, err) | ||
| sketch.n = maxItemsSeen | ||
|
|
||
| err = sketch.Update(1) | ||
| assert.ErrorContains(t, err, "sketch has exceeded capacity") | ||
| } | ||
|
|
||
| func TestReservoirItemsSketchForceIncrementItemsSeenReturnsErrorOnOverflow(t *testing.T) { | ||
| sketch, err := NewReservoirItemsSketch[int64](8) | ||
| assert.NoError(t, err) | ||
| sketch.n = maxItemsSeen - 1 | ||
|
|
||
| err = sketch.forceIncrementItemsSeen(2) | ||
| assert.ErrorContains(t, err, "sketch has exceeded capacity") | ||
| } | ||
|
|
||
| func TestReservoirItemsSketchFromSliceRejectsNTooLarge(t *testing.T) { | ||
| data := make([]byte, 16) | ||
| data[0] = 0xC0 | preambleIntsNonEmpty | ||
| data[1] = serVer | ||
| data[2] = byte(internal.FamilyEnum.ReservoirItems.Id) | ||
| data[3] = 0 | ||
| binary.LittleEndian.PutUint32(data[4:], uint32(8)) | ||
| binary.LittleEndian.PutUint64(data[8:], uint64(maxItemsSeen+1)) | ||
|
|
||
| _, err := NewReservoirItemsSketchFromSlice[int64](data, Int64SerDe{}) | ||
| assert.ErrorContains(t, err, "items seen exceeds limit") | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can define ResizeFactor using iota, so the verbose
`case` syntax isn't needed.