Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 116 additions & 55 deletions src/main/java/org/apache/datasketches/count/CountMinSketch.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@
import org.apache.datasketches.common.Family;
import org.apache.datasketches.common.SketchesArgumentException;
import org.apache.datasketches.common.SketchesException;
import org.apache.datasketches.common.Util;
import org.apache.datasketches.hash.MurmurHash3;
import org.apache.datasketches.tuple.Util;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import java.util.Random;

import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED;
import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED;
import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED;


public class CountMinSketch {
private final byte numHashes_;
Comment thread
freakyzoidberg marked this conversation as resolved.
Expand All @@ -39,6 +44,9 @@ public class CountMinSketch {
private final long[] sketchArray_;
private long totalWeight_;

// Thread-local 8-byte scratch MemorySegment reused by longToBytes(); values are written with
// JAVA_LONG_UNALIGNED, i.e. in the platform's native byte order (no explicit byte order is set).
private static final ThreadLocal<MemorySegment> LONG_SEGMENT =
ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[8]));

private enum Flag {
IS_EMPTY;
Expand All @@ -57,30 +65,59 @@ int mask() {
* @param seed The base hash seed
*/
CountMinSketch(final byte numHashes, final int numBuckets, final long seed) {
numHashes_ = numHashes;
numBuckets_ = numBuckets;
seed_ = seed;
hashSeeds_ = new long[numHashes];
sketchArray_ = new long[numHashes * numBuckets];
totalWeight_ = 0;
// Validate numHashes
if (numHashes <= 0) {
throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes);
}

// Validate numBuckets with clear mathematical justification
if (numBuckets <= 0) {
throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets);
}
if (numBuckets < 3) {
throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1.");
throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " +
"With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets));
Comment thread
freakyzoidberg marked this conversation as resolved.
}

// Check for potential overflow in array size calculation
// Use long arithmetic to detect overflow before casting
final long totalSize = (long) numHashes * (long) numBuckets;
if (totalSize > Integer.MAX_VALUE) {
throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets +
" = " + totalSize + " > " + Integer.MAX_VALUE);
}

// This check is to ensure later compatibility with a Java implementation whose maximum size can only
// be 2^31-1. We check only against 2^30 for simplicity.
if (numBuckets * numHashes >= 1 << 30) {
throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" +
"Try reducing either the number of buckets or the number of hash functions.");
if (totalSize >= (1L << 30)) {
throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets +
" = " + totalSize + " elements (~" + String.format("%.1f", totalSize * 8.0 / (1024 * 1024 * 1024)) + " GB). " +
"Consider reducing numHashes or numBuckets.");
}

numHashes_ = numHashes;
numBuckets_ = numBuckets;
seed_ = seed;
hashSeeds_ = new long[numHashes];
sketchArray_ = new long[(int) totalSize];
totalWeight_ = 0;

Random rand = new Random(seed);
for (int i = 0; i < numHashes; i++) {
Comment thread
freakyzoidberg marked this conversation as resolved.
hashSeeds_[i] = rand.nextLong();
}
}

/**
 * Converts a {@code long} item to its 8-byte form so the {@code byte[]} overloads can be reused.
 *
 * <p>The value is written into the calling thread's scratch segment and then copied out, so each
 * call returns a new array that is independent of the shared scratch storage.</p>
 *
 * <p>NOTE(review): {@code JAVA_LONG_UNALIGNED} uses the platform's native byte order, whereas
 * {@code ByteBuffer.putLong} defaults to big-endian. Confirm that the byte order (and therefore
 * the hash of a long item) is meant to depend on the platform.</p>
 *
 * @param value the long to convert
 * @return a new 8-element array holding the bytes of {@code value}
 */
private static byte[] longToBytes(final long value) {
  final MemorySegment scratch = LONG_SEGMENT.get();
  final long startOffset = 0L;
  scratch.set(JAVA_LONG_UNALIGNED, startOffset, value);
  // toArray copies the segment contents into a fresh byte[].
  final byte[] itemBytes = scratch.toArray(JAVA_BYTE);
  return itemBytes;
}


private long[] getHashes(byte[] item) {
long[] updateLocations = new long[numHashes_];

Expand Down Expand Up @@ -171,8 +208,7 @@ public static int suggestNumBuckets(double relativeError) {
* @param weight The weight of the item.
*/
public void update(final long item, final long weight) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
update(longByte, weight);
update(longToBytes(item), weight);
}

/**
Expand Down Expand Up @@ -211,8 +247,7 @@ public void update(final byte[] item, final long weight) {
* @return Estimated frequency.
*/
public long getEstimate(final long item) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
return getEstimate(longByte);
return getEstimate(longToBytes(item));
}

/**
Expand Down Expand Up @@ -241,8 +276,9 @@ public long getEstimate(final byte[] item) {

long[] hashLocations = getHashes(item);
long res = sketchArray_[(int) hashLocations[0]];
for (long h : hashLocations) {
res = Math.min(res, sketchArray_[(int) h]);
// Start from index 1 to avoid processing first element twice
for (int i = 1; i < hashLocations.length; i++) {
res = Math.min(res, sketchArray_[(int) hashLocations[i]]);
}

return res;
Expand All @@ -254,8 +290,7 @@ public long getEstimate(final byte[] item) {
* @return Upper bound of estimated frequency.
*/
public long getUpperBound(final long item) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
return getUpperBound(longByte);
return getUpperBound(longToBytes(item));
}

/**
Expand Down Expand Up @@ -291,8 +326,7 @@ public long getUpperBound(final byte[] item) {
* @return Lower bound of estimated frequency.
*/
public long getLowerBound(final long item) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
return getLowerBound(longByte);
return getLowerBound(longToBytes(item));
}

/**
Expand Down Expand Up @@ -342,39 +376,62 @@ public void merge(final CountMinSketch other) {
}

/**
 * Returns the number of bytes needed to serialize this sketch.
 *
 * <p>An empty sketch is the preamble only. Otherwise the preamble is followed by one long for the
 * total weight and one long for every entry of the sketch array.</p>
 *
 * @return the serialized size in bytes
 * @throws ArithmeticException if the serialized size does not fit in an {@code int} and therefore
 *     cannot be held in a single byte array
 */
private int getSerializedSizeBytes() {
  final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES;
  if (isEmpty()) {
    return preambleBytes;
  }
  // Compute in long: the sketch array may hold close to 2^30 entries, and at 8 bytes per entry
  // the product silently wraps around in int arithmetic.
  final long sizeBytes = (long) preambleBytes + Long.BYTES + ((long) sketchArray_.length * Long.BYTES);
  if (sizeBytes > Integer.MAX_VALUE) {
    throw new ArithmeticException("Serialized sketch size of " + sizeBytes
        + " bytes exceeds the maximum byte array length of " + Integer.MAX_VALUE);
  }
  return (int) sizeBytes;
}


/**
* Returns the sketch as a byte array.
*/
public void serialize(ByteArrayOutputStream buf) {
public byte[] toByteArray() {
final int serializedSizeBytes = getSerializedSizeBytes();
final MemorySegment wseg = MemorySegment.ofArray(new byte[serializedSizeBytes]);

long offset = 0;

// Long 0
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
buf.write((byte) preambleLongs);
wseg.set(JAVA_BYTE, offset++, (byte) preambleLongs);
final int serialVersion = 1;
buf.write((byte) serialVersion);
wseg.set(JAVA_BYTE, offset++, (byte) serialVersion);
final int familyId = Family.COUNTMIN.getID();
buf.write((byte) familyId);
wseg.set(JAVA_BYTE, offset++, (byte) familyId);
final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0;
buf.write((byte)flagsByte);
wseg.set(JAVA_BYTE, offset++, (byte) flagsByte);
final int NULL_32 = 0;
buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array());
wseg.set(JAVA_INT_UNALIGNED, offset, NULL_32);
offset += 4;
Copy link
Copy Markdown
Member

@leerho leerho Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (and following) is a use-case where the new datasketches.common.positional.PositionalSegment could be used. You might want to look at it. I used the term "positional" instead of "buffer", because "buffer" is a widely used generic term and doesn't convey what is actually going on. The PS would eliminate all of the offset tracking that you had to do by hand (with magic numbers).

I had already migrated theta and tuple, but when I got to bloomfilter, I decided there was a real need for it -- so I paused and created it. It is already used in bloomfilter, tdigest, kll, and req. I intend to go back and use it in the rest of the library where it makes sense. And here it makes perfect sense, but this suggestion is optional.

Nonetheless, if you decide not to use PS, I would recommend replacing the magic numbers in the offset adjustments with the built-in static constants such as Integer.BYTES, Long.BYTES, etc. That makes it more obvious what the code is doing and makes it more robust. :-)

Again, thank you for your work on this!


// Long 1
buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array());
buf.write(numHashes_);
wseg.set(JAVA_INT_UNALIGNED, offset, numBuckets_);
offset += 4;
wseg.set(JAVA_BYTE, offset++, numHashes_);
short hashSeed = Util.computeSeedHash(seed_);
buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array());
wseg.set(JAVA_SHORT_UNALIGNED, offset, hashSeed);
offset += 2;
final byte NULL_8 = 0;
buf.write(NULL_8);
wseg.set(JAVA_BYTE, offset++, NULL_8);

if (isEmpty()) {
return;
return wseg.toArray(JAVA_BYTE);
}

final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array();
buf.writeBytes(totWeightByte);
wseg.set(JAVA_LONG_UNALIGNED, offset, totalWeight_);
offset += 8;

for (long w: sketchArray_) {
buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array());
wseg.set(JAVA_LONG_UNALIGNED, offset, w);
offset += 8;
}

return wseg.toArray(JAVA_BYTE);
}

/**
Expand All @@ -384,20 +441,22 @@ public void serialize(ByteArrayOutputStream buf) {
* @return The deserialized CountMinSketch.
*/
public static CountMinSketch deserialize(final byte[] b, final long seed) {
ByteBuffer buf = ByteBuffer.allocate(b.length);
buf.put(b);
buf.flip();

final byte preambleLongs = buf.get();
final byte serialVersion = buf.get();
final byte familyId = buf.get();
final byte flagsByte = buf.get();
final int NULL_32 = buf.getInt();

final int numBuckets = buf.getInt();
final byte numHashes = buf.get();
final short seedHash = buf.getShort();
final byte NULL_8 = buf.get();
final MemorySegment buf = MemorySegment.ofArray(b);
long offset = 0;

final byte preambleLongs = buf.get(JAVA_BYTE, offset++);
final byte serialVersion = buf.get(JAVA_BYTE, offset++);
final byte familyId = buf.get(JAVA_BYTE, offset++);
final byte flagsByte = buf.get(JAVA_BYTE, offset++);
final int NULL_32 = buf.get(JAVA_INT_UNALIGNED, offset);
offset += 4;

final int numBuckets = buf.get(JAVA_INT_UNALIGNED, offset);
offset += 4;
final byte numHashes = buf.get(JAVA_BYTE, offset++);
final short seedHash = buf.get(JAVA_SHORT_UNALIGNED, offset);
offset += 2;
final byte NULL_8 = buf.get(JAVA_BYTE, offset++);

Comment thread
freakyzoidberg marked this conversation as resolved.
if (seedHash != Util.computeSeedHash(seed)) {
throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
Expand All @@ -409,11 +468,13 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) {
if (empty) {
return cms;
}
long w = buf.getLong();
long w = buf.get(JAVA_LONG_UNALIGNED, offset);
offset += 8;
cms.totalWeight_ = w;

for (int i = 0; i < cms.sketchArray_.length; i++) {
cms.sketchArray_[i] = buf.getLong();
cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED, offset);
offset += 8;
}

return cms;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,7 @@ public void serializeDeserializeEmptyTest() {
final long seed = 123456;
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);

ByteArrayOutputStream buf = new ByteArrayOutputStream();
c.serialize(buf);

byte[] b = buf.toByteArray();
byte[] b = c.toByteArray();
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1));

CountMinSketch d = CountMinSketch.deserialize(b, seed);
Expand All @@ -228,11 +225,10 @@ public void serializeDeserializeTest() {
c.update(i, 10*i*i);
}

ByteArrayOutputStream buf = new ByteArrayOutputStream();
c.serialize(buf);
byte[] b = c.toByteArray();

assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(buf.toByteArray(), seed - 1));
CountMinSketch d = CountMinSketch.deserialize(buf.toByteArray(), seed);
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1));
CountMinSketch d = CountMinSketch.deserialize(b, seed);

assertEquals(d.getNumHashes_(), c.getNumHashes_());
assertEquals(d.getNumBuckets_(), c.getNumBuckets_());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public void checkAllFlavorsGo() throws IOException {
int flavorIdx = 0;
for (int n: nArr) {
final byte[] bytes = Files.readAllBytes(goPath.resolve("cpc_n" + n + "_go.sk"));
final CpcSketch sketch = CpcSketch.heapify(Memory.wrap(bytes));
final CpcSketch sketch = CpcSketch.heapify(MemorySegment.ofArray(bytes));
assertEquals(sketch.getFlavor(), flavorArr[flavorIdx++]);
Comment thread
freakyzoidberg marked this conversation as resolved.
assertEquals(sketch.getEstimate(), n, n * 0.02);
}
Expand Down