diff --git a/internal/family.go b/internal/family.go index 45fb4df..bd3f392 100644 --- a/internal/family.go +++ b/internal/family.go @@ -34,6 +34,7 @@ type families struct { TDigest family ReservoirItems family VarOptItems family + VarOptUnion family ReservoirUnion family } @@ -82,6 +83,10 @@ var FamilyEnum = &families{ Id: 13, MaxPreLongs: 4, }, + VarOptUnion: family{ + Id: 14, + MaxPreLongs: 4, + }, ReservoirUnion: family{ Id: 12, MaxPreLongs: 1, diff --git a/sampling/compatibility_test.go b/sampling/compatibility_test.go index d51b7c2..87a5725 100644 --- a/sampling/compatibility_test.go +++ b/sampling/compatibility_test.go @@ -247,8 +247,403 @@ func TestGenerateGoBinariesForCompatibilityTesting(t *testing.T) { }) } }) + + // ========== VarOptItemsSketch (8 files) ========== + // Matches Java/C++ VarOptCrossLanguageTest / serialize_for_java scenarios. + t.Run("varopt_sketch_long", func(t *testing.T) { + nArr := []int{0, 1, 10, 100, 1000, 10000, 100000, 1000000} + for _, n := range nArr { + n := n + t.Run(fmt.Sprintf("n%d", n), func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](32) + assert.NoError(t, err) + for i := 1; i <= n; i++ { + assert.NoError(t, sketch.Update(int64(i), 1.0)) + } + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + assert.NoError(t, os.WriteFile( + fmt.Sprintf("%s/varopt_sketch_long_n%d_go.sk", internal.GoPath, n), + data, + 0644, + )) + }) + } + }) + + // ========== VarOptItemsSketch exact (1 file) ========== + t.Run("varopt_sketch_string_exact", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[string](1024) + assert.NoError(t, err) + for i := 1; i <= 200; i++ { + assert.NoError(t, sketch.Update(fmt.Sprintf("%d", i), 1000.0/float64(i))) + } + data, err := sketch.ToSlice(StringSerDe{}) + assert.NoError(t, err) + assert.NoError(t, os.WriteFile( + fmt.Sprintf("%s/varopt_sketch_string_exact_go.sk", internal.GoPath), + data, + 0644, + )) + }) + + // ========== VarOptItemsSketch sampling (1 file) ========== + t.Run("varopt_sketch_long_sampling", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](1024) + assert.NoError(t, err) + for i := int64(0); i < 2000; i++ { + assert.NoError(t, sketch.Update(i, 1.0)) + } + // Negative heavy items to allow predicate filtering, aligned with Java/C++. + assert.NoError(t, sketch.Update(-1, 100000.0)) + assert.NoError(t, sketch.Update(-2, 110000.0)) + assert.NoError(t, sketch.Update(-3, 120000.0)) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + assert.NoError(t, os.WriteFile( + fmt.Sprintf("%s/varopt_sketch_long_sampling_go.sk", internal.GoPath), + data, + 0644, + )) + }) + + // ========== VarOptItemsUnion sampling (1 file) ========== + t.Run("varopt_union_double_sampling", func(t *testing.T) { + const ( + kSmall = 16 + kMax = 128 + n1 = 32 + n2 = 64 + ) + + // Small-k sketch in sampling mode. + sketch1, err := NewVarOptItemsSketch[float64](kSmall) + assert.NoError(t, err) + for i := 0; i < n1; i++ { + assert.NoError(t, sketch1.Update(float64(i), 1.0)) + } + assert.NoError(t, sketch1.Update(-1.0, float64(n1*n1))) // negative heavy item + + // Another sketch with different n to yield a different implicit per-item weight. + sketch2, err := NewVarOptItemsSketch[float64](kSmall) + assert.NoError(t, err) + for i := 0; i < n2; i++ { + assert.NoError(t, sketch2.Update(float64(i), 1.0)) + } + + union, err := NewVarOptItemsUnion[float64](kMax) + assert.NoError(t, err) + assert.NoError(t, union.UpdateSketch(sketch1)) + assert.NoError(t, union.UpdateSketch(sketch2)) + + data, err := union.ToSlice(Float64SerDe{}) + assert.NoError(t, err) + assert.NoError(t, os.WriteFile( + fmt.Sprintf("%s/varopt_union_double_sampling_go.sk", internal.GoPath), + data, + 0644, + )) + }) } +func TestVarOptItemsSketch_JavaCompat(t *testing.T) { + nArr := []int{0, 1, 10, 100, 1000, 10000, 100000, 1000000} + for _, n := range nArr { + t.Run(fmt.Sprintf("long_n%d", n), func(t *testing.T) { + path := filepath.Join(internal.JavaPath, fmt.Sprintf("varopt_sketch_long_n%d_java.sk", n)) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Java file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 32, sketch.K()) + assert.Equal(t, int64(n), sketch.N()) + assert.Equal(t, min(n, 32), sketch.NumSamples()) + }) + } + + t.Run("string_exact", func(t *testing.T) { + path := filepath.Join(internal.JavaPath, "varopt_sketch_string_exact_java.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Java file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[string](data, StringSerDe{}) + assert.NoError(t, err) + assert.Equal(t, 1024, sketch.K()) + assert.Equal(t, int64(200), sketch.N()) + assert.Equal(t, 200, sketch.NumSamples()) + + ss, err := sketch.EstimateSubsetSum(func(_ string) bool { return true }) + assert.NoError(t, err) + weight := 0.0 + for i := 1; i <= 200; i++ { + weight += 1000.0 / float64(i) + } + assert.InDelta(t, weight, ss.TotalSketchWeight, 1e-9) + }) + + t.Run("long_sampling", func(t *testing.T) { + path := filepath.Join(internal.JavaPath, "varopt_sketch_long_sampling_java.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Java file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 1024, sketch.K()) + assert.Equal(t, int64(2003), sketch.N()) + assert.Equal(t, 1024, sketch.NumSamples()) + + ssAll, err := sketch.EstimateSubsetSum(func(_ int64) bool { return true }) + assert.NoError(t, err) + assert.InDelta(t, 332000.0, ssAll.TotalSketchWeight, 1e-9) + + ssNeg, err := sketch.EstimateSubsetSum(func(v int64) bool { return v < 0 }) + assert.NoError(t, err) + assert.InDelta(t, 330000.0, ssNeg.Estimate, 1e-9) + + ssNonNeg, err := sketch.EstimateSubsetSum(func(v int64) bool { return v >= 0 }) + assert.NoError(t, err) + assert.InDelta(t, 2000.0, ssNonNeg.Estimate, 1e-9) + }) +} + +func TestVarOptItemsSketch_CppCompat(t *testing.T) { + nArr := []int{0, 1, 10, 100, 1000, 10000, 100000, 1000000} + for _, n := range nArr { + t.Run(fmt.Sprintf("long_n%d", n), func(t *testing.T) { + path := filepath.Join(internal.CppPath, fmt.Sprintf("varopt_sketch_long_n%d_cpp.sk", n)) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("CPP file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 32, sketch.K()) + assert.Equal(t, int64(n), sketch.N()) + assert.Equal(t, min(n, 32), sketch.NumSamples()) + }) + } + + t.Run("string_exact", func(t *testing.T) { + path := filepath.Join(internal.CppPath, "varopt_sketch_string_exact_cpp.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("CPP file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[string](data, StringSerDe{}) + assert.NoError(t, err) + assert.Equal(t, 1024, sketch.K()) + assert.Equal(t, int64(200), sketch.N()) + assert.Equal(t, 200, sketch.NumSamples()) + + ss, err := sketch.EstimateSubsetSum(func(_ string) bool { return true }) + assert.NoError(t, err) + weight := 0.0 + for i := 1; i <= 200; i++ { + weight += 1000.0 / float64(i) + } + assert.InDelta(t, weight, ss.TotalSketchWeight, 1e-9) + }) + + t.Run("long_sampling", func(t *testing.T) { + path := filepath.Join(internal.CppPath, "varopt_sketch_long_sampling_cpp.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("CPP file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 1024, sketch.K()) + assert.Equal(t, int64(2003), sketch.N()) + assert.Equal(t, 1024, sketch.NumSamples()) + + ssAll, err := sketch.EstimateSubsetSum(func(_ int64) bool { return true }) + assert.NoError(t, err) + assert.InDelta(t, 332000.0, ssAll.TotalSketchWeight, 1e-9) + + ssNeg, err := sketch.EstimateSubsetSum(func(v int64) bool { return v < 0 }) + assert.NoError(t, err) + assert.InDelta(t, 330000.0, ssNeg.Estimate, 1e-9) + + ssNonNeg, err := sketch.EstimateSubsetSum(func(v int64) bool { return v >= 0 }) + assert.NoError(t, err) + assert.InDelta(t, 2000.0, ssNonNeg.Estimate, 1e-9) + }) +} + +func TestVarOptItemsSketch_GoCompat(t *testing.T) { + nArr := []int{0, 1, 10, 100, 1000, 10000, 100000, 1000000} + for _, n := range nArr { + t.Run(fmt.Sprintf("long_n%d", n), func(t *testing.T) { + path := filepath.Join(internal.GoPath, fmt.Sprintf("varopt_sketch_long_n%d_go.sk", n)) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Go file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 32, sketch.K()) + assert.Equal(t, int64(n), sketch.N()) + assert.Equal(t, min(n, 32), sketch.NumSamples()) + }) + } + + t.Run("string_exact", func(t *testing.T) { + path := filepath.Join(internal.GoPath, "varopt_sketch_string_exact_go.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Go file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[string](data, StringSerDe{}) + assert.NoError(t, err) + assert.Equal(t, 1024, sketch.K()) + assert.Equal(t, int64(200), sketch.N()) + assert.Equal(t, 200, sketch.NumSamples()) + + ss, err := sketch.EstimateSubsetSum(func(_ string) bool { return true }) + assert.NoError(t, err) + weight := 0.0 + for i := 1; i <= 200; i++ { + weight += 1000.0 / float64(i) + } + assert.InDelta(t, weight, ss.TotalSketchWeight, 1e-9) + }) + + t.Run("long_sampling", func(t *testing.T) { + path := filepath.Join(internal.GoPath, "varopt_sketch_long_sampling_go.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Go file not found: %s", path) + return + } + + data, err := os.ReadFile(path) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 1024, sketch.K()) + assert.Equal(t, int64(2003), sketch.N()) + assert.Equal(t, 1024, sketch.NumSamples()) + + ssAll, err := sketch.EstimateSubsetSum(func(_ int64) bool { return true }) + assert.NoError(t, err) + assert.InDelta(t, 332000.0, ssAll.TotalSketchWeight, 1e-9) + + ssNeg, err := sketch.EstimateSubsetSum(func(v int64) bool { return v < 0 }) + assert.NoError(t, err) + assert.InDelta(t, 330000.0, ssNeg.Estimate, 1e-9) + + ssNonNeg, err := sketch.EstimateSubsetSum(func(v int64) bool { return v >= 0 }) + assert.NoError(t, err) + assert.InDelta(t, 2000.0, ssNonNeg.Estimate, 1e-9) + }) +} + +func TestVarOptItemsUnion_JavaCompat(t *testing.T) { + path := filepath.Join(internal.JavaPath, "varopt_union_double_sampling_java.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Java file not found: %s", path) + return + } + data, err := os.ReadFile(path) + assert.NoError(t, err) + + union, err := NewVarOptItemsUnionFromSlice[float64](data, Float64SerDe{}) + assert.NoError(t, err) + result, err := union.Result() + assert.NoError(t, err) + assert.True(t, result.K() < 128) + assert.Equal(t, int64(97), result.N()) + + ss, err := result.EstimateSubsetSum(func(v float64) bool { return v >= 0 }) + assert.NoError(t, err) + assert.InDelta(t, 96.0, ss.Estimate, 1e-9) + assert.InDelta(t, 96.0+1024.0, ss.TotalSketchWeight, 1e-9) +} + +func TestVarOptItemsUnion_CppCompat(t *testing.T) { + path := filepath.Join(internal.CppPath, "varopt_union_double_sampling_cpp.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("CPP file not found: %s", path) + return + } + data, err := os.ReadFile(path) + assert.NoError(t, err) + + union, err := NewVarOptItemsUnionFromSlice[float64](data, Float64SerDe{}) + assert.NoError(t, err) + result, err := union.Result() + assert.NoError(t, err) + assert.True(t, result.K() < 128) + assert.True(t, result.K() >= 16) + assert.Equal(t, int64(97), result.N()) + + ss, err := result.EstimateSubsetSum(func(v float64) bool { return v >= 0 }) + assert.NoError(t, err) + assert.InDelta(t, 96.0, ss.Estimate, 1e-9) + assert.InDelta(t, 96.0+1024.0, ss.TotalSketchWeight, 1e-9) +} + +func TestVarOptItemsUnion_GoCompat(t *testing.T) { + path := filepath.Join(internal.GoPath, "varopt_union_double_sampling_go.sk") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Skipf("Go file not found: %s", path) + return + } + data, err := os.ReadFile(path) + assert.NoError(t, err) + + union, err := NewVarOptItemsUnionFromSlice[float64](data, Float64SerDe{}) + assert.NoError(t, err) + result, err := union.Result() + assert.NoError(t, err) + assert.True(t, result.K() < 128) + assert.True(t, result.K() >= 16) + assert.Equal(t, int64(97), result.N()) + + ss, err := result.EstimateSubsetSum(func(v float64) bool { return v >= 0 }) + assert.NoError(t, err) + assert.InDelta(t, 96.0, ss.Estimate, 1e-9) + assert.InDelta(t, 96.0+1024.0, ss.TotalSketchWeight, 1e-9) +} + + // TestSerializationCompatibilityEmpty tests deserialization of an empty sketch. func TestSerializationCompatibilityEmpty(t *testing.T) { filename := filepath.Join(internal.GoPath, "reservoir_items_long_empty_k128_go.sk") diff --git a/sampling/varopt_items_serde_test.go b/sampling/varopt_items_serde_test.go new file mode 100644 index 0000000..12ce6cd --- /dev/null +++ b/sampling/varopt_items_serde_test.go @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sampling + +import ( + "encoding/binary" + "math" + "testing" + + "github.com/apache/datasketches-go/internal" + "github.com/stretchr/testify/assert" +) + +type emptyCorruptingInt64SerDe struct{} + +func (emptyCorruptingInt64SerDe) SerializeToBytes(items []int64) ([]byte, error) { + if len(items) == 0 { + return []byte{0xCA, 0xFE, 0xBA, 0xBE}, nil + } + return Int64SerDe{}.SerializeToBytes(items) +} + +func (emptyCorruptingInt64SerDe) DeserializeFromBytes(data []byte, numItems int) ([]int64, error) { + return Int64SerDe{}.DeserializeFromBytes(data, numItems) +} + +func (emptyCorruptingInt64SerDe) SizeOfItem() int { + return Int64SerDe{}.SizeOfItem() +} + +func TestVarOptItemsSketchSerde_EmptyRoundTrip(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + restored, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.True(t, restored.IsEmpty()) + assert.Equal(t, 16, restored.K()) +} + +func TestVarOptItemsSketchSerde_EmptySketchIgnoresCustomEmptyItemsBytes(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + + data, err := sketch.ToSlice(emptyCorruptingInt64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, 8, len(data)) + + restored, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.True(t, restored.IsEmpty()) + assert.Equal(t, 16, restored.K()) + assert.Equal(t, int64(0), restored.N()) +} + +func TestVarOptItemsSketchSerde_WarmupRoundTrip(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 10; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + restored, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, sketch.K(), restored.K()) + assert.Equal(t, sketch.N(), restored.N()) + assert.Equal(t, sketch.H(), restored.H()) + assert.Equal(t, sketch.R(), restored.R()) + assert.Greater(t, cap(restored.data), restored.H()) + assert.Equal(t, cap(restored.data), cap(restored.weights)) +} + +func TestVarOptItemsSketchSerde_SamplingRoundTrip(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 80; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + assert.Greater(t, sketch.R(), 0) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + restored, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, sketch.K(), restored.K()) + assert.Equal(t, sketch.N(), restored.N()) + assert.Equal(t, sketch.H(), restored.H()) + assert.Equal(t, sketch.R(), restored.R()) + assert.InDelta(t, sketch.totalWeightR, restored.totalWeightR, 1e-9) +} + +func TestVarOptItemsUnionSerde_EmptyRoundTrip(t *testing.T) { + union, err := NewVarOptItemsUnion[int64](16) + assert.NoError(t, err) + + data, err := union.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + restored, err := NewVarOptItemsUnionFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, union.maxK, restored.maxK) + + result, err := restored.Result() + assert.NoError(t, err) + assert.True(t, result.IsEmpty()) +} + +func TestVarOptItemsUnionSerde_NonEmptyRoundTrip(t *testing.T) { + union, err := NewVarOptItemsUnion[int64](16) + assert.NoError(t, err) + + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 10; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + assert.NoError(t, union.UpdateSketch(sketch)) + + data, err := union.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + restored, err := NewVarOptItemsUnionFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + assert.Equal(t, union.maxK, restored.maxK) + + result, err := restored.Result() + assert.NoError(t, err) + assert.Equal(t, sketch.N(), result.N()) + assert.Equal(t, sketch.NumSamples(), result.NumSamples()) +} + +func TestVarOptItemsSketchSerde_HeaderConsistency(t *testing.T) { + // preLongs says empty, but empty flag is not set. + data := make([]byte, 8) + data[0] = byte(varOptPreambleLongsEmpty) + data[1] = varOptSerVer + data[2] = byte(internal.FamilyEnum.VarOptItems.Id) + data[3] = 0 + binary.LittleEndian.PutUint32(data[4:], 8) + + _, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "empty preLongs without empty flag") +} + +func TestVarOptItemsSketchSerde_WarmupDataWithFullPreLongsIsInvalid(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 10; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + assert.Equal(t, 0, sketch.R()) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + data[0] = (data[0] & 0xC0) | byte(varOptPreambleLongsFull) + + _, err = NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "n <= k but not in warmup mode") +} + +func TestVarOptItemsSketchSerde_WarmupModeRequiresNEqualsH(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 10; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + binary.LittleEndian.PutUint64(data[8:], uint64(9)) + + _, err = NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "warmup mode but n != h") +} + +func TestVarOptItemsSketchSerde_WarmupModeRequiresRZero(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 10; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + binary.LittleEndian.PutUint32(data[20:], uint32(1)) + + _, err = NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "warmup mode but r > 0") +} + +func TestVarOptItemsSketchSerde_FullModeRequiresHSumREqualsK(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 80; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + assert.Greater(t, sketch.R(), 0) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + h := binary.LittleEndian.Uint32(data[16:]) + binary.LittleEndian.PutUint32(data[16:], h-1) + + _, err = NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "full mode but h + r != k") +} + +func TestVarOptItemsSketchSerde_NGreaterThanKRequiresFullMode(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 80; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + assert.Greater(t, sketch.N(), int64(sketch.K())) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + data[0] = (data[0] & 0xC0) | byte(varOptPreambleLongsWarmup) + + _, err = NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "n > k but not in full mode") +} + +func TestVarOptItemsSketchSerde_FullModeRequiresRPositive(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 80; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + assert.Greater(t, sketch.R(), 0) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + binary.LittleEndian.PutUint32(data[16:], uint32(sketch.K())) + binary.LittleEndian.PutUint32(data[20:], uint32(0)) + + _, err = NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "full mode but r == 0") +} + +func TestVarOptItemsSketchSerde_NaNTotalWeightRIsInvalid(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int64](16) + assert.NoError(t, err) + for i := int64(1); i <= 80; i++ { + assert.NoError(t, sketch.Update(i, float64(i))) + } + assert.Greater(t, sketch.R(), 0) + + data, err := sketch.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + + binary.LittleEndian.PutUint64(data[24:], math.Float64bits(math.NaN())) + + _, err = NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "invalid totalWeightR") +} + +func TestVarOptItemsUnionSerde_HeaderConsistency(t *testing.T) { + // preLongs says empty, but empty flag is not set. + data := make([]byte, 8) + data[0] = byte(varOptUnionPreambleLongsEmpty) + data[1] = varOptUnionSerVer + data[2] = byte(internal.FamilyEnum.VarOptUnion.Id) + data[3] = 0 + binary.LittleEndian.PutUint32(data[4:], 8) + + _, err := NewVarOptItemsUnionFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "empty preLongs without empty flag") +} diff --git a/sampling/varopt_items_sketch_serde.go b/sampling/varopt_items_sketch_serde.go new file mode 100644 index 0000000..6e1d222 --- /dev/null +++ b/sampling/varopt_items_sketch_serde.go @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sampling + +import ( + "encoding/binary" + "errors" + "math" + + "github.com/apache/datasketches-go/internal" +) + +const ( + varOptPreambleLongsEmpty = 1 + varOptPreambleLongsWarmup = 3 + varOptPreambleLongsFull = 4 + + varOptSerVer = 2 + varOptFlagEmpty = 0x04 + varOptFlagGadget = 0x80 +) + +// ToSlice serializes the sketch to a byte slice using Java/C++ compatible preamble layout. +func (s *VarOptItemsSketch[T]) ToSlice(serde ItemsSerDe[T]) ([]byte, error) { + rfBits, err := resizeFactorBitsFor(s.rf) + if err != nil { + return nil, err + } + + flags := byte(0) + if s.marks != nil { + flags |= varOptFlagGadget + } + + preLongs := varOptPreambleLongsEmpty + totalItems := 0 + if s.IsEmpty() { + flags |= varOptFlagEmpty + } else { + totalItems = s.h + s.r + if s.r == 0 { + preLongs = varOptPreambleLongsWarmup + } else { + preLongs = varOptPreambleLongsFull + } + } + + weightsBytes := s.h * 8 + markBytes := 0 + if s.marks != nil { + markBytes = packedBoolBytes(s.h) + } + + var items []T + if totalItems > 0 { + items = make([]T, 0, totalItems) + for i := 0; i < s.h; i++ { + items = append(items, s.data[i]) + } + for i := s.h + 1; i <= s.k && s.r > 0; i++ { + items = append(items, s.data[i]) + } + } + + itemsBytes := []byte(nil) + if totalItems > 0 { + itemsBytes, err = serde.SerializeToBytes(items) + if err != nil { + return nil, err + } + } + + preambleBytes := preLongs * 8 + out := make([]byte, preambleBytes+weightsBytes+markBytes+len(itemsBytes)) + + out[0] = rfBits | byte(preLongs) + out[1] = varOptSerVer + out[2] = byte(internal.FamilyEnum.VarOptItems.Id) + out[3] = flags + binary.LittleEndian.PutUint32(out[4:], uint32(s.k)) + + if !s.IsEmpty() { + binary.LittleEndian.PutUint64(out[8:], uint64(s.n)) + binary.LittleEndian.PutUint32(out[16:], uint32(s.h)) + binary.LittleEndian.PutUint32(out[20:], uint32(s.r)) + if s.r > 0 { + binary.LittleEndian.PutUint64(out[24:], math.Float64bits(s.totalWeightR)) + } + } + + weightOffset := preambleBytes + if !s.IsEmpty() { + weightOffset = 24 + if s.r > 0 { + weightOffset += 8 + } + } + for i := 0; i < s.h; i++ { + binary.LittleEndian.PutUint64(out[weightOffset+i*8:], math.Float64bits(s.weights[i])) + } + + markOffset := weightOffset + weightsBytes + if s.marks != nil && s.h > 0 { + packBoolsInto(out[markOffset:markOffset+markBytes], s.marks[:s.h]) + } + + if totalItems > 0 { + copy(out[markOffset+markBytes:], itemsBytes) + } + return out, nil +} + +// NewVarOptItemsSketchFromSlice deserializes a sketch from bytes. +func NewVarOptItemsSketchFromSlice[T any](data []byte, serde ItemsSerDe[T]) (*VarOptItemsSketch[T], error) { + if len(data) < 8 { + return nil, errors.New("data too short") + } + + preLongs := int(data[0] & 0x3F) + rf, err := resizeFactorFromHeaderByte(data[0]) + if err != nil { + return nil, err + } + ver := data[1] + family := data[2] + flags := data[3] + k := int(binary.LittleEndian.Uint32(data[4:])) + + if ver != varOptSerVer { + return nil, errors.New("unsupported serialization version") + } + if family != byte(internal.FamilyEnum.VarOptItems.Id) { + return nil, errors.New("wrong sketch family") + } + if k < 1 || k > varOptMaxK { + return nil, errors.New("invalid k in serialized varopt sketch") + } + + hasEmptyFlag := (flags & varOptFlagEmpty) != 0 + if preLongs == varOptPreambleLongsEmpty && !hasEmptyFlag { + return nil, errors.New("invalid varopt sketch header: empty preLongs without empty flag") + } + if preLongs != varOptPreambleLongsEmpty && hasEmptyFlag { + return nil, errors.New("invalid varopt sketch header: non-empty preLongs with empty flag") + } + + isEmpty := hasEmptyFlag + isGadget := (flags & varOptFlagGadget) != 0 + if isEmpty { + out, err := NewVarOptItemsSketch[T](uint(k), WithResizeFactor(rf)) + if err != nil { + return nil, err + } + if isGadget { + out.marks = make([]bool, 0, cap(out.data)) + } + return out, nil + } + + if preLongs != varOptPreambleLongsWarmup && preLongs != varOptPreambleLongsFull { + return nil, errors.New("invalid preLongs for non-empty varopt sketch") + } + if len(data) < preLongs*8 { + return nil, errors.New("data too short for varopt preamble") + } + + n := int64(binary.LittleEndian.Uint64(data[8:])) + h := int(binary.LittleEndian.Uint32(data[16:])) + r := int(binary.LittleEndian.Uint32(data[20:])) + if h < 0 || r < 0 { + return nil, errors.New("invalid h/r in serialized varopt sketch") + } + if n < 0 { + return nil, errors.New("invalid n in serialized varopt sketch: negative") + } + + if n <= int64(k) { + if preLongs != varOptPreambleLongsWarmup { + return nil, errors.New("invalid varopt sketch state: n <= k but not in warmup mode") + } + if int64(h) != n { + return nil, errors.New("invalid varopt sketch state: warmup mode but n != h") + } + if r > 0 { + return nil, errors.New("invalid varopt sketch state: warmup mode but r > 0") + } + } else { + if preLongs != varOptPreambleLongsFull { + return nil, errors.New("invalid varopt sketch state: n > k but not in full mode") + } + if h+r != k { + return nil, errors.New("invalid varopt sketch state: full mode but h + r != k") + } + if r == 0 { + return nil, errors.New("invalid varopt sketch state: full mode but r == 0") + } + } + + totalWeightR := 0.0 + if r > 0 { + totalWeightR = math.Float64frombits(binary.LittleEndian.Uint64(data[24:])) + if math.IsNaN(totalWeightR) || math.IsInf(totalWeightR, 0) || totalWeightR <= 0 { + return nil, errors.New("invalid totalWeightR in serialized varopt sketch: non-positive or non-finite") + } + } + + weightOffset := 24 + if r > 0 { + weightOffset += 8 + } + weightsBytes := h * 8 + if len(data) < weightOffset+weightsBytes { + return nil, errors.New("data too short for varopt weights") + } + + hWeights := make([]float64, h) + for i := 0; i < h; i++ { + w := math.Float64frombits(binary.LittleEndian.Uint64(data[weightOffset+i*8:])) + if w <= 0 || math.IsNaN(w) || math.IsInf(w, 0) { + return nil, errors.New("invalid non-positive or non-finite weight in serialized varopt sketch") + } + hWeights[i] = w + } + + markOffset := weightOffset + weightsBytes + hMarks := make([]bool, h) + numMarksInH := uint32(0) + if isGadget && h > 0 { + markBytes := packedBoolBytes(h) + if len(data) < markOffset+markBytes { + return nil, errors.New("data too short for varopt marks") + } + unpackBoolsFrom(data[markOffset:markOffset+markBytes], hMarks) + for _, m := range hMarks { + if m { + numMarksInH++ + } + } + markOffset += markBytes + } + + totalItems := h + r + items, err := serde.DeserializeFromBytes(data[markOffset:], totalItems) + if err != nil { + return nil, err + } + + if r == 0 { + ceilingLgK := math.Log2(float64(internal.CeilPowerOf2(k))) + initialLgSize := startingSubMultiple(int(ceilingLgK), int(rf), minLgArrItems) + warmupCap := adjustedSamplingAllocationSize(k, 1< 0 { + tau := sketch.tau() + cumWeight := 0.0 + rSeen := 0 + for i := sketch.h + 1; i <= sketch.k; i++ { + w := tau + // Match Java/C++ weight-correcting iterator semantics: + // correct the last R item to absorb floating-point residual. + if rSeen == sketch.r-1 { + w = sketch.totalWeightR - cumWeight + } else { + cumWeight += tau + } + rSeen++ + if err := u.gadget.update(sketch.data[i], w, true); err != nil { + return err + } + } + u.resolveOuterTau(sketch) + } + + return nil +} + +// Result returns the current union result sketch. +// +// If marked items remain in H, full resolution logic is required and is implemented +// in the next step. For now we fail fast with a clear error. +func (u *VarOptItemsUnion[T]) Result() (*VarOptItemsSketch[T], error) { + if u.gadget == nil || u.gadget.N() == 0 { + return NewVarOptItemsSketch[T](uint(u.maxK)) + } + + if u.gadget.numMarksInH == 0 { + out := copyVarOptItemsSketch(u.gadget, true) + out.n = u.n + return out, nil + } + + // Marked items in H require the full resolution path. + if out, ok, err := u.detectAndHandleSubcaseOfPseudoExact(); err != nil { + return nil, err + } else if ok { + return out, nil + } + return u.migrateMarkedItemsByDecreasingK() +} + +func (u *VarOptItemsUnion[T]) resolveOuterTau(sketch *VarOptItemsSketch[T]) { + if sketch.r == 0 { + return + } + + sketchTau := sketch.tau() + if u.outerTauDenom == 0 { + u.outerTauNumer = sketch.totalWeightR + u.outerTauDenom = int64(sketch.r) + return + } + + outerTau := u.outerTauNumer / float64(u.outerTauDenom) + if sketchTau > outerTau { + u.outerTauNumer = sketch.totalWeightR + u.outerTauDenom = int64(sketch.r) + return + } + if sketchTau == outerTau { + u.outerTauNumer += sketch.totalWeightR + u.outerTauDenom += int64(sketch.r) + } +} + +func newVarOptItemsSketchAsGadget[T any](k int) (*VarOptItemsSketch[T], error) { + sketch, err := NewVarOptItemsSketch[T](uint(k)) + if err != nil { + return nil, err + } + sketch.marks = make([]bool, 0, cap(sketch.data)) + return sketch, nil +} + +func (u *VarOptItemsUnion[T]) detectAndHandleSubcaseOfPseudoExact() (*VarOptItemsSketch[T], bool, error) { + condition1 := u.gadget.r == 0 + condition2 := u.gadget.numMarksInH > 0 + condition3 := int64(u.gadget.numMarksInH) == u.outerTauDenom + + if !(condition1 && condition2 && condition3) { + return nil, false, nil + } + + if u.thereExistUnmarkedHItemsLighterThanTarget(u.gadget.tau()) { + return nil, false, nil + } + + out, err := u.markMovingGadgetCoercer() + if err != nil { + return nil, false, err + } + return out, true, nil +} + +func (u *VarOptItemsUnion[T]) thereExistUnmarkedHItemsLighterThanTarget(threshold float64) bool { + for i := 0; i < u.gadget.h; i++ { + if u.gadget.weights[i] < threshold && !u.gadget.marks[i] { + return true + } + } + return false +} + +func (u *VarOptItemsUnion[T]) markMovingGadgetCoercer() (*VarOptItemsSketch[T], error) { + resultK := u.gadget.h + u.gadget.r + resultH := 0 + resultR := 0 + nextRPos := resultK + + data := make([]T, resultK+1) + weights := make([]float64, resultK+1) + + // Move existing R region items first (weight remains implicit via totalWeightR). + for i := u.gadget.h + 1; i <= u.gadget.k && i < len(u.gadget.data); i++ { + data[nextRPos] = u.gadget.data[i] + weights[nextRPos] = -1.0 + resultR++ + nextRPos-- + } + + transferredWeight := 0.0 + for i := 0; i < u.gadget.h; i++ { + if u.gadget.marks[i] { + data[nextRPos] = u.gadget.data[i] + weights[nextRPos] = -1.0 + transferredWeight += u.gadget.weights[i] + resultR++ + nextRPos-- + } else { + data[resultH] = u.gadget.data[i] + weights[resultH] = u.gadget.weights[i] + resultH++ + } + } + + if resultH+resultR != resultK { + return nil, errors.New("invalid state resolving pseudo-exact union gadget") + } + if math.Abs(transferredWeight-u.outerTauNumer) > 1e-10 { + return nil, errors.New("unexpected mismatch in transferred weight") + } + + // Gap slot. + weights[resultH] = -1.0 + + out := &VarOptItemsSketch[T]{ + data: data, + weights: weights, + k: resultK, + n: u.n, + h: resultH, + m: 0, + r: resultR, + totalWeightR: u.gadget.totalWeightR + transferredWeight, + rf: varOptDefaultResizeFactor, + numMarksInH: 0, + } + + if err := out.heapify(); err != nil { + return nil, err + } + return out, nil +} + +func (u *VarOptItemsUnion[T]) migrateMarkedItemsByDecreasingK() (*VarOptItemsSketch[T], error) { + gcopy := copyVarOptItemsSketch(u.gadget, false) + gcopy.n = u.n + + rCount := gcopy.r + hCount := gcopy.h + k := gcopy.k + + // If non-full and pseudo-exact, set k to sample count so reductions increase tau. + if rCount == 0 && hCount < k { + gcopy.k = hCount + } + + if gcopy.k < 2 { + return nil, errors.New("cannot resolve marked items with k < 2") + } + if err := decreaseKBy1(gcopy); err != nil { + return nil, err + } + + for gcopy.numMarksInH > 0 { + if gcopy.k < 2 { + return nil, errors.New("cannot continue resolving marked items with k < 2") + } + if err := decreaseKBy1(gcopy); err != nil { + return nil, err + } + } + + gcopy.numMarksInH = 0 + gcopy.marks = nil + return gcopy, nil +} + +func decreaseKBy1[T any](s *VarOptItemsSketch[T]) error { + if s.k <= 1 { + return errors.New("cannot decrease k below 1 in union") + } + + switch { + case s.h == 0 && s.r == 0: + s.k-- + return nil + case s.h > 0 && s.r == 0: + s.k-- + if s.h > s.k { + return s.transitionFromWarmup() + } + return nil + case s.h > 0 && s.r > 0: + oldGapIdx := s.h + oldFinalRIdx := (s.h + 1 + s.r) - 1 + s.swap(oldFinalRIdx, oldGapIdx) + + pulledIdx := s.h - 1 + pulledItem := s.data[pulledIdx] + pulledWeight := s.weights[pulledIdx] + pulledMark := s.marks[pulledIdx] + + if pulledMark { + s.numMarksInH-- + } + s.weights[pulledIdx] = -1.0 + + s.h-- + s.k-- + s.n-- + return s.update(pulledItem, pulledWeight, pulledMark) + case s.h == 0 && s.r > 0: + if s.r < 2 { + return errors.New("invalid pure-reservoir state while decreasing k") + } + rIdxToDelete := 1 + rand.Intn(s.r) + rightmostRIdx := (1 + s.r) - 1 + s.swap(rIdxToDelete, rightmostRIdx) + s.weights[rightmostRIdx] = -1.0 + + s.k-- + s.r-- + return nil + default: + return errors.New("invalid sketch state while decreasing k") + } +} + +func copyVarOptItemsSketch[T any](in *VarOptItemsSketch[T], asSketch bool) *VarOptItemsSketch[T] { + dataCopy := make([]T, len(in.data)) + copy(dataCopy, in.data) + + weightsCopy := make([]float64, len(in.weights)) + copy(weightsCopy, in.weights) + + var marksCopy []bool + numMarksInH := in.numMarksInH + if !asSketch && in.marks != nil { + marksCopy = make([]bool, len(in.marks)) + copy(marksCopy, in.marks) + } else { + numMarksInH = 0 + } + + return &VarOptItemsSketch[T]{ + data: dataCopy, + weights: weightsCopy, + marks: marksCopy, + k: in.k, + n: in.n, + h: in.h, + m: in.m, + r: in.r, + totalWeightR: in.totalWeightR, + rf: in.rf, + numMarksInH: numMarksInH, + } +} + +// ToSlice serializes the union state to bytes. +func (u *VarOptItemsUnion[T]) ToSlice(serde ItemsSerDe[T]) ([]byte, error) { + empty := u.gadget == nil || u.gadget.N() == 0 + if empty { + out := make([]byte, varOptUnionPreambleLongsEmpty*8) + out[0] = byte(varOptUnionPreambleLongsEmpty) + out[1] = varOptUnionSerVer + out[2] = byte(internal.FamilyEnum.VarOptUnion.Id) + out[3] = varOptUnionFlagEmpty + binary.LittleEndian.PutUint32(out[4:], uint32(u.maxK)) + return out, nil + } + + gadgetBytes, err := u.gadget.ToSlice(serde) + if err != nil { + return nil, err + } + + preBytes := varOptUnionPreambleLongsNonEmpty * 8 + out := make([]byte, preBytes+len(gadgetBytes)) + out[0] = byte(varOptUnionPreambleLongsNonEmpty) + out[1] = varOptUnionSerVer + out[2] = byte(internal.FamilyEnum.VarOptUnion.Id) + out[3] = 0 + binary.LittleEndian.PutUint32(out[4:], uint32(u.maxK)) + binary.LittleEndian.PutUint64(out[8:], uint64(u.n)) + binary.LittleEndian.PutUint64(out[16:], math.Float64bits(u.outerTauNumer)) + binary.LittleEndian.PutUint64(out[24:], uint64(u.outerTauDenom)) + copy(out[preBytes:], gadgetBytes) + return out, nil +} + +// NewVarOptItemsUnionFromSlice deserializes union state from bytes. +func NewVarOptItemsUnionFromSlice[T any](data []byte, serde ItemsSerDe[T]) (*VarOptItemsUnion[T], error) { + if len(data) < 8 { + return nil, errors.New("data too short") + } + preLongs := int(data[0] & 0x3F) + ver := data[1] + family := data[2] + flags := data[3] + maxK := int(binary.LittleEndian.Uint32(data[4:])) + + if ver != varOptUnionSerVer { + return nil, errors.New("unsupported serialization version") + } + if family != byte(internal.FamilyEnum.VarOptUnion.Id) { + return nil, errors.New("wrong sketch family") + } + if preLongs != varOptUnionPreambleLongsEmpty && preLongs != varOptUnionPreambleLongsNonEmpty { + return nil, errors.New("invalid preLongs for varopt union") + } + + union, err := NewVarOptItemsUnion[T](maxK) + if err != nil { + return nil, err + } + + hasEmptyFlag := (flags & varOptUnionFlagEmpty) != 0 + if preLongs == varOptUnionPreambleLongsEmpty && !hasEmptyFlag { + return nil, errors.New("invalid varopt union header: empty preLongs without empty flag") + } + if preLongs != varOptUnionPreambleLongsEmpty && hasEmptyFlag { + return nil, errors.New("invalid varopt union header: non-empty preLongs with empty flag") + } + + if hasEmptyFlag { + return union, nil + } + + if len(data) < varOptUnionPreambleLongsNonEmpty*8 { + return nil, errors.New("data too short for non-empty varopt union") + } + + union.n = int64(binary.LittleEndian.Uint64(data[8:])) + union.outerTauNumer = math.Float64frombits(binary.LittleEndian.Uint64(data[16:])) + union.outerTauDenom = int64(binary.LittleEndian.Uint64(data[24:])) + + gadget, err := NewVarOptItemsSketchFromSlice[T](data[varOptUnionPreambleLongsNonEmpty*8:], serde) + if err != nil { + return nil, err + } + union.gadget = gadget + return union, nil +} diff --git a/sampling/varopt_items_union_preamble.go b/sampling/varopt_items_union_preamble.go new file mode 100644 index 0000000..ae2f1eb --- /dev/null +++ b/sampling/varopt_items_union_preamble.go @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sampling + +// VarOpt union serialization constants aligned with Java/C++ preamble definitions. +const ( + varOptUnionPreambleLongsEmpty = 1 + varOptUnionPreambleLongsNonEmpty = 4 + varOptUnionSerVer = 2 + varOptUnionFlagEmpty = 0x04 +) + diff --git a/sampling/varopt_items_union_test.go b/sampling/varopt_items_union_test.go new file mode 100644 index 0000000..190be8c --- /dev/null +++ b/sampling/varopt_items_union_test.go @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sampling + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewVarOptItemsUnion(t *testing.T) { + _, err := NewVarOptItemsUnion[int](0) + assert.ErrorContains(t, err, "k must be at least 1") + + union, err := NewVarOptItemsUnion[int](16) + assert.NoError(t, err) + assert.Equal(t, 16, union.maxK) +} + +func TestVarOptItemsUnion_ResultEmpty(t *testing.T) { + union, err := NewVarOptItemsUnion[int](8) + assert.NoError(t, err) + + result, err := union.Result() + assert.NoError(t, err) + assert.Equal(t, 8, result.K()) + assert.Equal(t, int64(0), result.N()) + assert.True(t, result.IsEmpty()) +} + +func TestVarOptItemsUnion_UpdateSketchExactMode(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](16) + assert.NoError(t, err) + + for i := 0; i < 8; i++ { + assert.NoError(t, sketch.Update(i, float64(i+1))) + } + + union, err := NewVarOptItemsUnion[int](16) + assert.NoError(t, err) + assert.NoError(t, union.UpdateSketch(sketch)) + + result, err := union.Result() + assert.NoError(t, err) + assert.Equal(t, int64(8), result.N()) + assert.Equal(t, 8, result.NumSamples()) + assert.Equal(t, 8, result.H()) + assert.Equal(t, 0, result.R()) +} + +func TestVarOptItemsUnion_UpdateSketchSamplingWithExtremeHeavyItem(t *testing.T) { + const k = 16 + + sketch1, err := NewVarOptItemsSketch[int](uint(k)) + assert.NoError(t, err) + for i := 0; i < 500; i++ { + assert.NoError(t, sketch1.Update(i, 1.0)) + } + assert.NoError(t, sketch1.Update(-1, 1e12)) + assert.Greater(t, sketch1.R(), 0) + + sketch2, err := NewVarOptItemsSketch[int](uint(k)) + assert.NoError(t, err) + for i := 1000; i < 1500; i++ { + assert.NoError(t, sketch2.Update(i, 1.0)) + } + assert.Greater(t, sketch2.R(), 0) + + union, err := NewVarOptItemsUnion[int](k) + assert.NoError(t, err) + assert.NoError(t, union.UpdateSketch(sketch1)) + assert.NoError(t, union.UpdateSketch(sketch2)) + + result, err := union.Result() + assert.NoError(t, err) + assert.Equal(t, int64(1001), result.N()) + assert.Equal(t, k, result.K()) + assert.LessOrEqual(t, result.NumSamples(), k) + + foundHeavy := false + for sample := range result.All() { + if sample.Item == -1 { + foundHeavy = true + break + } + } + assert.True(t, foundHeavy, "extreme heavy item should be retained in union result") +} + +func TestVarOptItemsUnion_UpdateSketchIdenticalSamplingSketches(t *testing.T) { + const k = 16 + const n = 1000 + + base, err := NewVarOptItemsSketch[int64](uint(k)) + assert.NoError(t, err) + for i := 0; i < n; i++ { + assert.NoError(t, base.Update(int64(i), 1.0)) + } + assert.Greater(t, base.R(), 0) + + data, err := base.ToSlice(Int64SerDe{}) + assert.NoError(t, err) + clone, err := NewVarOptItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.NoError(t, err) + + union, err := NewVarOptItemsUnion[int64](k) + assert.NoError(t, err) + assert.NoError(t, union.UpdateSketch(base)) + assert.NoError(t, union.UpdateSketch(clone)) + + result, err := union.Result() + assert.NoError(t, err) + assert.Equal(t, int64(2*n), result.N()) + assert.Equal(t, k, result.K()) + assert.LessOrEqual(t, result.NumSamples(), k) + + ss, err := result.EstimateSubsetSum(func(_ int64) bool { return true }) + assert.NoError(t, err) + assert.InDelta(t, float64(2*n), ss.TotalSketchWeight, 1e-9) +} + +func TestVarOptItemsUnion_UpdateSketchDifferentKWeightedItems(t *testing.T) { + smallK := 8 + largeK := 32 + + small, err := NewVarOptItemsSketch[int](uint(smallK)) + assert.NoError(t, err) + totalWeight := 0.0 + for i := 1; i <= 200; i++ { + w := float64(i) + totalWeight += w + assert.NoError(t, small.Update(i, w)) + } + assert.Greater(t, small.R(), 0) + + large, err := NewVarOptItemsSketch[int](uint(largeK)) + assert.NoError(t, err) + for i := 1; i <= 400; i++ { + w := float64(i) * 0.5 + totalWeight += w + assert.NoError(t, large.Update(10000+i, w)) + } + assert.Greater(t, large.R(), 0) + + union, err := NewVarOptItemsUnion[int](largeK) + assert.NoError(t, err) + assert.NoError(t, union.UpdateSketch(small)) + assert.NoError(t, union.UpdateSketch(large)) + + result, err := union.Result() + assert.NoError(t, err) + assert.Equal(t, int64(600), result.N()) + assert.GreaterOrEqual(t, result.K(), 1) + assert.LessOrEqual(t, result.K(), largeK) + assert.LessOrEqual(t, result.NumSamples(), largeK) + + ss, err := result.EstimateSubsetSum(func(_ int) bool { return true }) + assert.NoError(t, err) + assert.InDelta(t, totalWeight, ss.TotalSketchWeight, 1e-9) +} + +func TestVarOptItemsUnion_UpdateSketchNilNoop(t *testing.T) { + union, err := NewVarOptItemsUnion[int](8) + assert.NoError(t, err) + + assert.NoError(t, union.UpdateSketch(nil)) + + result, err := union.Result() + assert.NoError(t, err) + assert.Equal(t, int64(0), result.N()) +} + +func TestVarOptItemsUnion_Reset(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](8) + assert.NoError(t, err) + for i := 0; i < 4; i++ { + assert.NoError(t, sketch.Update(i, float64(i+1))) + } + + union, err := NewVarOptItemsUnion[int](8) + assert.NoError(t, err) + assert.NoError(t, union.UpdateSketch(sketch)) + + assert.NoError(t, union.Reset()) + + result, err := union.Result() + assert.NoError(t, err) + assert.True(t, result.IsEmpty()) + assert.Equal(t, 8, result.K()) +} + +func TestVarOptItemsUnion_ResultPseudoExactMarkedResolution(t *testing.T) { + union, err := NewVarOptItemsUnion[int](8) + assert.NoError(t, err) + + // Construct a pseudo-exact gadget: r=0 with marked items in H. + for i := 1; i <= 4; i++ { + assert.NoError(t, union.gadget.update(i, float64(i), true)) + } + union.n = 4 + union.outerTauDenom = int64(union.gadget.numMarksInH) + union.outerTauNumer = 10.0 + + out, err := union.Result() + assert.NoError(t, err) + assert.NotNil(t, out) + assert.Equal(t, int64(4), out.N()) + assert.Equal(t, 4, out.K()) + assert.Equal(t, 0, out.H()) + assert.Equal(t, 4, out.R()) + assert.Nil(t, out.marks) + assert.Equal(t, uint32(0), out.numMarksInH) +} + +func TestVarOptItemsUnion_ResultPseudoExactTransferredWeightMismatch(t *testing.T) { + union, err := NewVarOptItemsUnion[int](8) + assert.NoError(t, err) + + // Construct a pseudo-exact gadget where transferred marked-H weight is known. + assert.NoError(t, union.gadget.update(1, 1.0, true)) + assert.NoError(t, union.gadget.update(2, 2.0, true)) + assert.NoError(t, union.gadget.update(3, 3.0, true)) + union.n = 3 + union.outerTauDenom = int64(union.gadget.numMarksInH) + + // Intentionally break bookkeeping to ensure we fail fast like Java/C++ checks. + union.outerTauNumer = 123.456 + + out, err := union.Result() + assert.Nil(t, out) + assert.ErrorContains(t, err, "unexpected mismatch in transferred weight") +} + +func TestVarOptItemsUnion_ResultMigrateMarkedItemsByDecreasingK(t *testing.T) { + union, err := NewVarOptItemsUnion[int](8) + assert.NoError(t, err) + + // Construct a compact, valid estimation-mode gadget with one marked item in H. + // Layout: [H=0] [gap=1] [R=2], with k=2, h=1, r=1. + union.gadget = &VarOptItemsSketch[int]{ + data: []int{100, 0, 1}, + weights: []float64{10.0, -1.0, -1.0}, + marks: []bool{true, false, false}, + k: 2, + n: 10, + h: 1, + m: 0, + r: 1, + totalWeightR: 5.0, + rf: varOptDefaultResizeFactor, + numMarksInH: 1, + } + union.n = 10 + + out, err := union.Result() + assert.NoError(t, err) + assert.NotNil(t, out) + assert.Equal(t, int64(10), out.N()) + assert.Nil(t, out.marks) + assert.Equal(t, uint32(0), out.numMarksInH) +} diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n0_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n0_cpp.sk new file mode 100644 index 0000000..e4505fe Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n0_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1000000_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1000000_cpp.sk new file mode 100644 index 0000000..44aecca Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1000000_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n100000_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n100000_cpp.sk new file mode 100644 index 0000000..7f496fd Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n100000_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n10000_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n10000_cpp.sk new file mode 100644 index 0000000..1337d0f Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n10000_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1000_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1000_cpp.sk new file mode 100644 index 0000000..1ad81c9 Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1000_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n100_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n100_cpp.sk new file mode 100644 index 0000000..f779a78 Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n100_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n10_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n10_cpp.sk new file mode 100644 index 0000000..f1ac8c0 Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n10_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1_cpp.sk new file mode 100644 index 0000000..86f6bc3 Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_n1_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_long_sampling_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_long_sampling_cpp.sk new file mode 100644 index 0000000..9ccfd97 Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_long_sampling_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_sketch_string_exact_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_sketch_string_exact_cpp.sk new file mode 100644 index 0000000..2da7e4e Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_sketch_string_exact_cpp.sk differ diff --git a/serialization_test_data/cpp_generated_files/varopt_union_double_sampling_cpp.sk b/serialization_test_data/cpp_generated_files/varopt_union_double_sampling_cpp.sk new file mode 100644 index 0000000..812892e Binary files /dev/null and b/serialization_test_data/cpp_generated_files/varopt_union_double_sampling_cpp.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n0_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n0_go.sk new file mode 100644 index 0000000..e4505fe Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n0_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n1000000_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n1000000_go.sk new file mode 100644 index 0000000..be9ee7f Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n1000000_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n100000_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n100000_go.sk new file mode 100644 index 0000000..356b8ce Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n100000_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n10000_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n10000_go.sk new file mode 100644 index 0000000..e2276a2 Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n10000_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n1000_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n1000_go.sk new file mode 100644 index 0000000..fdf0fbb Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n1000_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n100_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n100_go.sk new file mode 100644 index 0000000..00bcb59 Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n100_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n10_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n10_go.sk new file mode 100644 index 0000000..f1ac8c0 Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n10_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_n1_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_n1_go.sk new file mode 100644 index 0000000..86f6bc3 Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_n1_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_long_sampling_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_long_sampling_go.sk new file mode 100644 index 0000000..ca94d5c Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_long_sampling_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_sketch_string_exact_go.sk b/serialization_test_data/go_generated_files/varopt_sketch_string_exact_go.sk new file mode 100644 index 0000000..2da7e4e Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_sketch_string_exact_go.sk differ diff --git a/serialization_test_data/go_generated_files/varopt_union_double_sampling_go.sk b/serialization_test_data/go_generated_files/varopt_union_double_sampling_go.sk new file mode 100644 index 0000000..6152fcc Binary files /dev/null and b/serialization_test_data/go_generated_files/varopt_union_double_sampling_go.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n0_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n0_java.sk new file mode 100644 index 0000000..e4505fe Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n0_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n1000000_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n1000000_java.sk new file mode 100644 index 0000000..b663f30 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n1000000_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n100000_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n100000_java.sk new file mode 100644 index 0000000..7df4970 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n100000_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n10000_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n10000_java.sk new file mode 100644 index 0000000..d120f49 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n10000_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n1000_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n1000_java.sk new file mode 100644 index 0000000..3a936e3 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n1000_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n100_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n100_java.sk new file mode 100644 index 0000000..eee954c Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n100_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n10_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n10_java.sk new file mode 100644 index 0000000..f1ac8c0 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n10_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_n1_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_n1_java.sk new file mode 100644 index 0000000..86f6bc3 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_n1_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_long_sampling_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_long_sampling_java.sk new file mode 100644 index 0000000..5e640c8 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_long_sampling_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_sketch_string_exact_java.sk b/serialization_test_data/java_generated_files/varopt_sketch_string_exact_java.sk new file mode 100644 index 0000000..2da7e4e Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_sketch_string_exact_java.sk differ diff --git a/serialization_test_data/java_generated_files/varopt_union_double_sampling_java.sk b/serialization_test_data/java_generated_files/varopt_union_double_sampling_java.sk new file mode 100644 index 0000000..edb18f4 Binary files /dev/null and b/serialization_test_data/java_generated_files/varopt_union_double_sampling_java.sk differ