Skip to content

Commit 3def49b

Browse files
authored
Merge pull request #724 from Fengzdadi/issue-721-varopt-rweight-validation
[Sampling] Validate VarOpt totalRWeight during heapify
2 parents e4f7a5b + b449e38 commit 3def49b

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

src/main/java/org/apache/datasketches/sampling/VarOptItemsSketch.java

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -331,6 +331,11 @@ public static <T> VarOptItemsSketch<T> heapify(final MemorySegment srcSeg,
331331
if (numPreLongs == Family.VAROPT.getMaxPreLongs()) {
332332
if (rCount > 0) {
333333
totalRWeight = extractTotalRWeight(srcSeg);
334+
if (Double.isNaN(totalRWeight) || (totalRWeight <= 0.0)) {
335+
throw new SketchesArgumentException("Possible Corruption: deserializing in full mode "
336+
+ "but invalid R region weight. Found r = " + rCount
337+
+ ", R region weight = " + totalRWeight);
338+
}
334339
} else {
335340
throw new SketchesArgumentException(
336341
"Possible Corruption: "

src/test/java/org/apache/datasketches/sampling/VarOptItemsSketchTest.java

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -288,6 +288,42 @@ public void checkCorruptSerializedWeight() {
288288
}
289289
}
290290

291+
@Test
292+
public void checkCorruptSerializedRWeightNaN() {
293+
final int k = 32;
294+
final VarOptItemsSketch<Long> sketch = getUnweightedLongsVIS(k, k + 1);
295+
final byte[] bytes = sketch.toByteArray(new ArrayOfLongsSerDe());
296+
final MemorySegment seg = MemorySegment.ofArray(bytes);
297+
assertEquals(PreambleUtil.extractPreLongs(seg), Family.VAROPT.getMaxPreLongs());
298+
299+
PreambleUtil.insertTotalRWeight(seg, Double.NaN);
300+
301+
try {
302+
VarOptItemsSketch.heapify(seg, new ArrayOfLongsSerDe());
303+
fail();
304+
} catch (final SketchesArgumentException e) {
305+
assertTrue(e.getMessage().contains("invalid R region weight"));
306+
}
307+
}
308+
309+
@Test
310+
public void checkCorruptSerializedRWeightZero() {
311+
final int k = 32;
312+
final VarOptItemsSketch<Long> sketch = getUnweightedLongsVIS(k, k + 1);
313+
final byte[] bytes = sketch.toByteArray(new ArrayOfLongsSerDe());
314+
final MemorySegment seg = MemorySegment.ofArray(bytes);
315+
assertEquals(PreambleUtil.extractPreLongs(seg), Family.VAROPT.getMaxPreLongs());
316+
317+
PreambleUtil.insertTotalRWeight(seg, 0.0);
318+
319+
try {
320+
VarOptItemsSketch.heapify(seg, new ArrayOfLongsSerDe());
321+
fail();
322+
} catch (final SketchesArgumentException e) {
323+
assertTrue(e.getMessage().contains("invalid R region weight"));
324+
}
325+
}
326+
291327
@Test
292328
public void checkCumulativeWeight() {
293329
final int k = 256;

0 commit comments

Comments
 (0)