Skip to content

Commit 4030940

Browse files
Merge pull request #8 from SaridakisStamatisChristos/codex/fix-test_merge_matches_extending-assertion-error
Ensure compaction retains boundary elements
2 parents d461b81 + 713b3ec commit 4030940

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

kll_sketch/kll_sketch.py

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -240,37 +240,43 @@ def _compress_once(self) -> bool:
240240

241241
buf = self._levels[lvl]
242242
buf.sort()
243+
if len(buf) < 3:
244+
# Not enough items to compact while preserving min/max boundaries.
245+
return False
246+
243247
rng = self._rng(salt=lvl + self._n + len(buf))
244248
keep_odd = rng.getrandbits(1) == 1
245249
start = 1 if keep_odd else 0
246250

247-
# Ensure we can form at least one pair; if not, flip parity once.
248-
if len(buf) - start < 2:
251+
# Always preserve explicit boundary elements.
252+
core = buf[1:-1]
253+
if len(core) < 2:
254+
return False
255+
256+
if len(core) - start < 2:
249257
keep_odd = not keep_odd
250258
start = 1 if keep_odd else 0
251-
if len(buf) - start < 2:
259+
if len(core) - start < 2:
252260
return False
253261

254262
promoted: List[float] = []
255-
# True KLL: choose one from each adjacent pair (unbiased)
256-
for i in range(start, len(buf) - 1, 2):
257-
promoted.append(buf[i] if rng.getrandbits(1) else buf[i + 1])
263+
for i in range(start, len(core) - 1, 2):
264+
promoted.append(core[i] if rng.getrandbits(1) else core[i + 1])
265+
266+
if not promoted:
267+
return False
258268

259-
# Boundary preservation: keep BOTH non-paired boundaries.
260-
leftover: List[float] = []
269+
leftover: List[float] = [buf[0]]
261270
if start == 1:
262-
# Front boundary not in any pair
263-
leftover.append(buf[0])
264-
if (len(buf) - start) % 2 == 1:
265-
# Tail boundary not in any pair
266-
tail = buf[-1]
267-
if not leftover or leftover[-1] != tail:
268-
leftover.append(tail)
271+
leftover.append(core[0])
272+
if (len(core) - start) % 2 == 1:
273+
leftover.append(core[-1])
274+
leftover.append(buf[-1])
269275

270276
self._levels[lvl] = leftover
271277
self._ensure_levels(lvl + 2)
272278
self._levels[lvl + 1].extend(promoted)
273-
return len(promoted) > 0
279+
return True
274280

275281
def _compress_until_ok(self) -> None:
276282
loops = 0

0 commit comments

Comments
 (0)