Skip to content

Commit 7a27a25

Browse files
authored
Merge pull request #127 from proost/refactor-remove-panic-in-cpc
refactor: remove panic in CPC sketch
2 parents e389956 + 3e027ce commit 7a27a25

2 files changed

Lines changed: 80 additions & 45 deletions

File tree

cpc/cpc_compressed_state.go

Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package cpc
1919

2020
import (
21+
"errors"
2122
"fmt"
2223
"math/bits"
2324

@@ -271,7 +272,10 @@ func (c *CpcCompressedState) compressHybridFlavor(src *CpcSketch) error {
271272
return fmt.Errorf("compressHybridFlavor: invariant violation (%d + %d != %d)",
272273
numPairsFromArray, srcNumPairs, srcNumCoupons)
273274
}
274-
allPairs := trickyGetPairsFromWindow(srcSlidingWindow, srcK, numPairsFromArray, srcNumPairs)
275+
allPairs, err := trickyGetPairsFromWindow(srcSlidingWindow, srcK, numPairsFromArray, srcNumPairs)
276+
if err != nil {
277+
return err
278+
}
275279
mergePairs(srcPairArr, 0, srcNumPairs, allPairs, srcNumPairs, numPairsFromArray, allPairs, 0)
276280
return compressTheSurprisingValues(c, src, allPairs, int(srcNumCoupons))
277281
}
@@ -428,7 +432,10 @@ func (c *CpcCompressedState) compressSlidingFlavor(src *CpcSketch) error {
428432
}
429433

430434
// Apply a transformation to the column indices.
431-
pseudoPhase := determinePseudoPhase(src.lgK, int64(src.numCoupons))
435+
pseudoPhase, err := determinePseudoPhase(src.lgK, int64(src.numCoupons))
436+
if err != nil {
437+
return err
438+
}
432439
if pseudoPhase >= 16 {
433440
return fmt.Errorf("compressSlidingFlavor: pseudoPhase (%d) >= 16", pseudoPhase)
434441
}
@@ -492,7 +499,10 @@ func (c *CpcCompressedState) uncompressSlidingFlavor(src *CpcSketch) error {
492499
}
493500

494501
// Determine pseudoPhase.
495-
pseudoPhase := determinePseudoPhase(srcLgK, int64(c.NumCoupons))
502+
pseudoPhase, err := determinePseudoPhase(srcLgK, int64(c.NumCoupons))
503+
if err != nil {
504+
return err
505+
}
496506
if pseudoPhase >= 16 {
497507
return fmt.Errorf("uncompressSlidingFlavor: pseudoPhase %d out of range", pseudoPhase)
498508
}
@@ -608,7 +618,7 @@ func importFromMemory(bytes []byte) (*CpcCompressedState, error) {
608618
state.CwStream = getWStream(bytes)
609619
state.CsvStream = getSvStream(bytes)
610620
default:
611-
panic("not implemented")
621+
return nil, fmt.Errorf("unknown format: %d", format)
612622
}
613623
return state, nil
614624
}
@@ -675,14 +685,23 @@ func compressTheSurprisingValues(target *CpcCompressedState, source *CpcSketch,
675685
// Compute srcK = 1 << source.lgK.
676686
srcK := 1 << source.lgK
677687
// Determine the number of base bits using a Golomb code decision.
678-
numBaseBits := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
688+
numBaseBits, err := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
689+
if err != nil {
690+
return err
691+
}
679692
// Compute an upper-bound length for the compressed pairs buffer.
680-
pairBufLen := safeLengthForCompressedPairBuf(srcK, numPairs, numBaseBits)
693+
pairBufLen, err := safeLengthForCompressedPairBuf(srcK, numPairs, numBaseBits)
694+
if err != nil {
695+
return err
696+
}
681697
// Allocate the buffer for compression.
682698
pairBuf := make([]int, pairBufLen)
683699
// lowLevelCompressPairs compresses 'pairs' using the chosen base bits into pairBuf.
684700
// It returns the number of ints that represent the compressed data.
685-
csvLength := lowLevelCompressPairs(pairs, numPairs, numBaseBits, pairBuf)
701+
csvLength, err := lowLevelCompressPairs(pairs, numPairs, numBaseBits, pairBuf)
702+
if err != nil {
703+
return err
704+
}
686705
target.CsvLengthInts = csvLength
687706
target.CsvStream = pairBuf
688707
return nil
@@ -696,32 +715,35 @@ func uncompressTheSurprisingValues(source *CpcCompressedState) ([]int, error) {
696715
}
697716
pairs := make([]int, numPairs)
698717
// Determine the number of base bits using the Golomb code decision.
699-
numBaseBits := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
718+
numBaseBits, err := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs)
719+
if err != nil {
720+
return nil, err
721+
}
700722
// lowLevelUncompressPairs fills the 'pairs' slice using the compressed CSV stream.
701723
if err := lowLevelUncompressPairs(pairs, numPairs, numBaseBits, source.CsvStream, source.CsvLengthInts); err != nil {
702724
return nil, err
703725
}
704726
return pairs, nil
705727
}
706728

707-
func golombChooseNumberOfBaseBits(k, count int) int {
729+
func golombChooseNumberOfBaseBits(k, count int) (int, error) {
708730
if k < 1 || count < 1 {
709-
panic("golombChooseNumberOfBaseBits: k and count must be >= 1")
731+
return 0, errors.New("golombChooseNumberOfBaseBits: k and count must be >= 1")
710732
}
711733
quotient := (k - count) / count
712734
if quotient == 0 {
713-
return 0
735+
return 0, nil
714736
}
715-
return floorLog2(uint64(quotient))
737+
return floorLog2(uint64(quotient)), nil
716738
}
717739

718740
func floorLog2(x uint64) int {
719741
return bits.Len64(x) - 1
720742
}
721743

722-
func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) int {
744+
func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) (int, error) {
723745
if numPairs <= 0 {
724-
panic("safeLengthForCompressedPairBuf: numPairs must be > 0")
746+
return 0, errors.New("safeLengthForCompressedPairBuf: numPairs must be > 0")
725747
}
726748
// Compute ybits = (numPairs * (1 + numBaseBits)) + (k >> numBaseBits)
727749
ybits := int64(numPairs)*(1+int64(numBaseBits)) + (int64(k) >> uint(numBaseBits))
@@ -736,9 +758,9 @@ func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) int {
736758
words := divideBy32RoundingUp(totalBits)
737759
// Ensure the number of words fits in a 31-bit int.
738760
if words >= (1 << 31) {
739-
panic("safeLengthForCompressedPairBuf: words too large")
761+
return 0, errors.New("safeLengthForCompressedPairBuf: words too large")
740762
}
741-
return int(words)
763+
return int(words), nil
742764
}
743765

744766
func divideBy32RoundingUp(x int64) int64 {
@@ -749,7 +771,7 @@ func divideBy32RoundingUp(x int64) int64 {
749771
return tmp + 1
750772
}
751773

752-
func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, compressedWords []int) int {
774+
func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, compressedWords []int) (int, error) {
753775
nextWordIndex := 0
754776
var bitBuf uint64 = 0
755777
bufBits := 0
@@ -773,8 +795,7 @@ func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, c
773795
predictedColIndex = 0
774796
}
775797
if rowIndex < predictedRowIndex || colIndex < predictedColIndex {
776-
panic(fmt.Sprintf("lowLevelCompressPairs: assertion failed: rowIndex=%d, predictedRowIndex=%d, colIndex=%d, predictedColIndex=%d",
777-
rowIndex, predictedRowIndex, colIndex, predictedColIndex))
798+
return 0, fmt.Errorf("lowLevelCompressPairs: assertion failed: rowIndex=%d, predictedRowIndex=%d, colIndex=%d, predictedColIndex=%d", rowIndex, predictedRowIndex, colIndex, predictedColIndex)
778799
}
779800

780801
// yDelta is the difference in row indices.
@@ -846,7 +867,7 @@ func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, c
846867
compressedWords[nextWordIndex] = int(bitBuf & 0xFFFFFFFF)
847868
nextWordIndex++
848869
}
849-
return nextWordIndex
870+
return nextWordIndex, nil
850871
}
851872

852873
func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int, compressedWords []int, numCompressedWords int) error {
@@ -889,7 +910,10 @@ func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int,
889910
ptrArr[NextWordIdx] = int64(nextWordIndex)
890911
ptrArr[BitBuf] = int64(bitBuf)
891912
ptrArr[BufBits] = int64(bufBits)
892-
golombHi := readUnary(compressedWords, ptrArr)
913+
golombHi, err := readUnary(compressedWords, ptrArr)
914+
if err != nil {
915+
return err
916+
}
893917
// Retrieve updated values.
894918
nextWordIndex = int(ptrArr[NextWordIdx])
895919
bitBuf = uint64(ptrArr[BitBuf])
@@ -931,7 +955,7 @@ func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int,
931955
return nil
932956
}
933957

934-
func readUnary(compressedWords []int, ptrArr []int64) int64 {
958+
func readUnary(compressedWords []int, ptrArr []int64) (int64, error) {
935959
nextWordIndex := int(ptrArr[NextWordIdx])
936960
bitBuf := uint64(ptrArr[BitBuf])
937961
bufBits := int(ptrArr[BufBits])
@@ -944,7 +968,7 @@ func readUnary(compressedWords []int, ptrArr []int64) int64 {
944968
// Ensure we have at least 8 bits in the bit buffer.
945969
if bufBits < 8 {
946970
if nextWordIndex >= len(compressedWords) {
947-
panic("readUnary: insufficient compressedWords data")
971+
return 0, errors.New("readUnary: insufficient compressedWords data")
948972
}
949973
bitBuf |= (uint64(compressedWords[nextWordIndex]) & 0xFFFFFFFF) << uint(bufBits)
950974
nextWordIndex++
@@ -975,7 +999,7 @@ func readUnary(compressedWords []int, ptrArr []int64) int64 {
975999
ptrArr[BitBuf] = int64(bitBuf)
9761000
ptrArr[BufBits] = int64(bufBits)
9771001

978-
return subTotal + int64(trailingZeros)
1002+
return subTotal + int64(trailingZeros), nil
9791003
}
9801004

9811005
func writeUnary(compressedWords []int, ptrArr []int64, theValue int) {
@@ -1011,7 +1035,7 @@ func writeUnary(compressedWords []int, ptrArr []int64, theValue int) {
10111035
ptrArr[BufBits] = int64(bufBits)
10121036
}
10131037

1014-
func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) []int {
1038+
func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) ([]int, error) {
10151039
outputLength := emptySpace + numPairsToGet
10161040
pairs := make([]int, outputLength)
10171041
pairIndex := emptySpace
@@ -1031,10 +1055,10 @@ func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) [
10311055
}
10321056

10331057
if pairIndex != outputLength {
1034-
panic(fmt.Sprintf("trickyGetPairsFromWindow: pairIndex (%d) != outputLength (%d)", pairIndex, outputLength))
1058+
return nil, fmt.Errorf("trickyGetPairsFromWindow: pairIndex (%d) != outputLength (%d)", pairIndex, outputLength)
10351059
}
10361060

1037-
return pairs
1061+
return pairs, nil
10381062
}
10391063

10401064
func (c *CpcCompressedState) compressTheWindow(src *CpcSketch) error {
@@ -1045,7 +1069,10 @@ func (c *CpcCompressedState) compressTheWindow(src *CpcSketch) error {
10451069
windowBufLen := safeLengthForCompressedWindowBuf(int64(srcK))
10461070
windowBuf := make([]int, windowBufLen)
10471071
// Determine the pseudo-phase using srcLgK and the number of coupons.
1048-
pseudoPhase := determinePseudoPhase(srcLgK, int64(src.numCoupons))
1072+
pseudoPhase, err := determinePseudoPhase(srcLgK, int64(src.numCoupons))
1073+
if err != nil {
1074+
return err
1075+
}
10491076
// Compress the sliding window bytes.
10501077
// lowLevelCompressBytes returns a single value (stored in cwLengthInts); it has no error return.
10511078
cwLengthInts := lowLevelCompressBytes(src.slidingWindow, srcK, encodingTablesForHighEntropyByte[pseudoPhase], windowBuf)
@@ -1069,7 +1096,10 @@ func uncompressTheWindow(target *CpcSketch, source *CpcCompressedState) error {
10691096
target.slidingWindow = window
10701097

10711098
// Determine the pseudo-phase using srcLgK and source.NumCoupons.
1072-
pseudoPhase := determinePseudoPhase(srcLgK, int64(source.NumCoupons))
1099+
pseudoPhase, err := determinePseudoPhase(srcLgK, int64(source.NumCoupons))
1100+
if err != nil {
1101+
return err
1102+
}
10731103
// Ensure that source.CwStream is not nil.
10741104
if source.CwStream == nil {
10751105
return fmt.Errorf("uncompressTheWindow: source.CwStream is nil")
@@ -1091,37 +1121,37 @@ func safeLengthForCompressedWindowBuf(k int64) int {
10911121
return int(divideBy32RoundingUp(totalBits))
10921122
}
10931123

1094-
func determinePseudoPhase(lgK int, numCoupons int64) int {
1124+
func determinePseudoPhase(lgK int, numCoupons int64) (int, error) {
10951125
k := int64(1) << uint(lgK)
10961126
c := numCoupons
10971127
// Midrange logic.
10981128
if (1000 * c) < (2375 * k) {
10991129
if (4 * c) < (3 * k) {
1100-
return 16 + 0
1130+
return 16 + 0, nil
11011131
} else if (10 * c) < (11 * k) {
1102-
return 16 + 1
1132+
return 16 + 1, nil
11031133
} else if (100 * c) < (132 * k) {
1104-
return 16 + 2
1134+
return 16 + 2, nil
11051135
} else if (3 * c) < (5 * k) {
1106-
return 16 + 3
1136+
return 16 + 3, nil
11071137
} else if (1000 * c) < (1965 * k) {
1108-
return 16 + 4
1138+
return 16 + 4, nil
11091139
} else if (1000 * c) < (2275 * k) {
1110-
return 16 + 5
1140+
return 16 + 5, nil
11111141
} else {
1112-
return 6 // steady-state table employed before its actual phase.
1142+
return 6, nil // steady-state table employed before its actual phase.
11131143
}
11141144
} else {
11151145
// Steady-state logic.
11161146
if lgK < 4 {
1117-
panic("determinePseudoPhase: lgK must be at least 4")
1147+
return 0, errors.New("determinePseudoPhase: lgK must be at least 4")
11181148
}
11191149
tmp := c >> uint(lgK-4)
11201150
phase := int(tmp & 15)
11211151
if phase < 0 || phase >= 16 {
1122-
panic(fmt.Sprintf("determinePseudoPhase: phase out of range: %d", phase))
1152+
return 0, fmt.Errorf("determinePseudoPhase: phase out of range: %d", phase)
11231153
}
1124-
return phase
1154+
return phase, nil
11251155
}
11261156
}
11271157

cpc/cpc_compressed_state_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import (
2323
"sort"
2424
"testing"
2525

26+
"github.com/stretchr/testify/assert"
27+
2628
"github.com/apache/datasketches-go/internal"
2729
)
2830

@@ -84,7 +86,8 @@ func TestWriteReadUnary(t *testing.T) {
8486
if nextWordIndex != int(ptrArr[NextWordIdx]) {
8587
t.Errorf("Before readUnary: nextWordIndex %d != ptrArr[NextWordIdx] %d", nextWordIndex, ptrArr[NextWordIdx])
8688
}
87-
result := readUnary(compressedWords, ptrArr)
89+
result, err := readUnary(compressedWords, ptrArr)
90+
assert.NoError(t, err)
8891
t.Logf("Result: %d, expected: %d", result, i)
8992
if result != int64(i) {
9093
t.Errorf("Mismatch: got %d, expected %d", result, i)
@@ -170,9 +173,10 @@ func TestWriteReadPairs(t *testing.T) {
170173
compressedWords := make([]int, MaxWords)
171174
// Loop over base bits 0 to 11.
172175
for bb := 0; bb <= 11; bb++ {
173-
numWordsWritten := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
176+
numWordsWritten, err := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
177+
assert.NoError(t, err)
174178
t.Logf("numWordsWritten = %d, bb = %d", numWordsWritten, bb)
175-
err := lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
179+
err = lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
176180
if err != nil {
177181
t.Errorf("Error in lowLevelUncompressPairs for bb=%d: %v", bb, err)
178182
}
@@ -390,9 +394,10 @@ func TestWriteReadPairsExtended(t *testing.T) {
390394
compressedWords := make([]int, MaxWords)
391395
// Loop over base bits 0 to 11.
392396
for bb := 0; bb <= 11; bb++ {
393-
numWordsWritten := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
397+
numWordsWritten, err := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords)
398+
assert.NoError(t, err)
394399
t.Logf("Base bits: %d, words written: %d", bb, numWordsWritten)
395-
err := lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
400+
err = lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten)
396401
if err != nil {
397402
t.Errorf("Error in lowLevelUncompressPairs for base bits %d: %v", bb, err)
398403
}

0 commit comments

Comments
 (0)