-
Notifications
You must be signed in to change notification settings - Fork 221
CMS and CPCxLangTest move to FFM, Build fix #676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
c091214
c6c90c3
e0445ef
ea9fda9
8bd1423
66f1333
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,14 +22,19 @@ | |
| import org.apache.datasketches.common.Family; | ||
| import org.apache.datasketches.common.SketchesArgumentException; | ||
| import org.apache.datasketches.common.SketchesException; | ||
| import org.apache.datasketches.common.Util; | ||
| import org.apache.datasketches.hash.MurmurHash3; | ||
| import org.apache.datasketches.tuple.Util; | ||
|
|
||
| import java.io.ByteArrayOutputStream; | ||
| import java.nio.ByteBuffer; | ||
| import java.lang.foreign.MemorySegment; | ||
| import java.nio.charset.StandardCharsets; | ||
|
|
||
| import static java.lang.foreign.ValueLayout.JAVA_BYTE; | ||
| import java.util.Random; | ||
|
|
||
| import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED; | ||
| import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED; | ||
| import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED; | ||
|
|
||
|
|
||
| public class CountMinSketch { | ||
| private final byte numHashes_; | ||
|
|
@@ -39,6 +44,9 @@ public class CountMinSketch { | |
| private final long[] sketchArray_; | ||
| private long totalWeight_; | ||
|
|
||
| // Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control | ||
| private static final ThreadLocal<MemorySegment> LONG_SEGMENT = | ||
| ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[8])); | ||
|
|
||
| private enum Flag { | ||
| IS_EMPTY; | ||
|
|
@@ -57,30 +65,59 @@ int mask() { | |
| * @param seed The base hash seed | ||
| */ | ||
| CountMinSketch(final byte numHashes, final int numBuckets, final long seed) { | ||
| numHashes_ = numHashes; | ||
| numBuckets_ = numBuckets; | ||
| seed_ = seed; | ||
| hashSeeds_ = new long[numHashes]; | ||
| sketchArray_ = new long[numHashes * numBuckets]; | ||
| totalWeight_ = 0; | ||
| // Validate numHashes | ||
| if (numHashes <= 0) { | ||
| throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes); | ||
| } | ||
|
|
||
| // Validate numBuckets with clear mathematical justification | ||
| if (numBuckets <= 0) { | ||
| throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets); | ||
| } | ||
| if (numBuckets < 3) { | ||
| throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1."); | ||
| throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " + | ||
| "With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets)); | ||
|
freakyzoidberg marked this conversation as resolved.
|
||
| } | ||
|
|
||
| // Check for potential overflow in array size calculation | ||
| // Use long arithmetic to detect overflow before casting | ||
| final long totalSize = (long) numHashes * (long) numBuckets; | ||
| if (totalSize > Integer.MAX_VALUE) { | ||
| throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets + | ||
| " = " + totalSize + " > " + Integer.MAX_VALUE); | ||
| } | ||
|
|
||
| // This check is to ensure later compatibility with a Java implementation whose maximum size can only | ||
| // be 2^31-1. We check only against 2^30 for simplicity. | ||
| if (numBuckets * numHashes >= 1 << 30) { | ||
| throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" + | ||
| "Try reducing either the number of buckets or the number of hash functions."); | ||
| if (totalSize >= (1L << 30)) { | ||
| throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets + | ||
| " = " + totalSize + " elements (~" + String.format("%.1f", totalSize * 8.0 / (1024 * 1024 * 1024)) + " GB). " + | ||
| "Consider reducing numHashes or numBuckets."); | ||
| } | ||
|
|
||
| numHashes_ = numHashes; | ||
| numBuckets_ = numBuckets; | ||
| seed_ = seed; | ||
| hashSeeds_ = new long[numHashes]; | ||
| sketchArray_ = new long[(int) totalSize]; | ||
| totalWeight_ = 0; | ||
|
|
||
| Random rand = new Random(seed); | ||
| for (int i = 0; i < numHashes; i++) { | ||
|
freakyzoidberg marked this conversation as resolved.
|
||
| hashSeeds_[i] = rand.nextLong(); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Efficiently converts a long to byte array using thread-local MemorySegment with explicit endianness. | ||
| */ | ||
| private static byte[] longToBytes(final long value) { | ||
| final MemorySegment segment = LONG_SEGMENT.get(); | ||
| segment.set(JAVA_LONG_UNALIGNED, 0, value); | ||
| return segment.toArray(JAVA_BYTE); | ||
| } | ||
|
|
||
|
|
||
| private long[] getHashes(byte[] item) { | ||
| long[] updateLocations = new long[numHashes_]; | ||
|
|
||
|
|
@@ -171,8 +208,7 @@ public static int suggestNumBuckets(double relativeError) { | |
| * @param weight The weight of the item. | ||
| */ | ||
| public void update(final long item, final long weight) { | ||
| byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); | ||
| update(longByte, weight); | ||
| update(longToBytes(item), weight); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -211,8 +247,7 @@ public void update(final byte[] item, final long weight) { | |
| * @return Estimated frequency. | ||
| */ | ||
| public long getEstimate(final long item) { | ||
| byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); | ||
| return getEstimate(longByte); | ||
| return getEstimate(longToBytes(item)); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -241,8 +276,9 @@ public long getEstimate(final byte[] item) { | |
|
|
||
| long[] hashLocations = getHashes(item); | ||
| long res = sketchArray_[(int) hashLocations[0]]; | ||
| for (long h : hashLocations) { | ||
| res = Math.min(res, sketchArray_[(int) h]); | ||
| // Start from index 1 to avoid processing first element twice | ||
| for (int i = 1; i < hashLocations.length; i++) { | ||
| res = Math.min(res, sketchArray_[(int) hashLocations[i]]); | ||
| } | ||
|
|
||
| return res; | ||
|
|
@@ -254,8 +290,7 @@ public long getEstimate(final byte[] item) { | |
| * @return Upper bound of estimated frequency. | ||
| */ | ||
| public long getUpperBound(final long item) { | ||
| byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); | ||
| return getUpperBound(longByte); | ||
| return getUpperBound(longToBytes(item)); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -291,8 +326,7 @@ public long getUpperBound(final byte[] item) { | |
| * @return Lower bound of estimated frequency. | ||
| */ | ||
| public long getLowerBound(final long item) { | ||
| byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); | ||
| return getLowerBound(longByte); | ||
| return getLowerBound(longToBytes(item)); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -342,39 +376,62 @@ public void merge(final CountMinSketch other) { | |
| } | ||
|
|
||
| /** | ||
| * Serializes the sketch into the provided ByteBuffer. | ||
| * @param buf The ByteBuffer to write into. | ||
| * Returns the serialized size in bytes. | ||
| */ | ||
| private int getSerializedSizeBytes() { | ||
| final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES; | ||
| if (isEmpty()) { | ||
| return preambleBytes; | ||
| } | ||
| return preambleBytes + Long.BYTES + (sketchArray_.length * Long.BYTES); | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Returns the sketch as a byte array. | ||
| */ | ||
| public void serialize(ByteArrayOutputStream buf) { | ||
| public byte[] toByteArray() { | ||
| final int serializedSizeBytes = getSerializedSizeBytes(); | ||
| final MemorySegment wseg = MemorySegment.ofArray(new byte[serializedSizeBytes]); | ||
|
|
||
| long offset = 0; | ||
|
|
||
| // Long 0 | ||
| final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); | ||
| buf.write((byte) preambleLongs); | ||
| wseg.set(JAVA_BYTE, offset++, (byte) preambleLongs); | ||
| final int serialVersion = 1; | ||
| buf.write((byte) serialVersion); | ||
| wseg.set(JAVA_BYTE, offset++, (byte) serialVersion); | ||
| final int familyId = Family.COUNTMIN.getID(); | ||
| buf.write((byte) familyId); | ||
| wseg.set(JAVA_BYTE, offset++, (byte) familyId); | ||
| final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0; | ||
| buf.write((byte)flagsByte); | ||
| wseg.set(JAVA_BYTE, offset++, (byte) flagsByte); | ||
| final int NULL_32 = 0; | ||
| buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array()); | ||
| wseg.set(JAVA_INT_UNALIGNED, offset, NULL_32); | ||
| offset += 4; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This (and the following code) is a use-case where the new datasketches.common.positional.PositionalSegment could be used. You might want to look at it. I used the term "positional" instead of "buffer", because "buffer" is a widely used generic term and doesn't convey what is actually going on. The PS would eliminate all of the offset tracking that you had to do by hand (with magic numbers). I had already migrated theta and tuple, but when I got to bloomfilter, I decided there was a real need for it -- so I paused and created it. It is already used in bloomfilter, tdigest, kll, and req. I intend to go back and use it in the rest of the library where it makes sense. And here it makes perfect sense, but this suggestion is optional. Nonetheless, if you decide not to use PS, I would recommend not using magic numbers for the offset adjustments and instead using the built-in static constants such as Integer.BYTES, Long.BYTES, etc. The code will become more obvious as to what you are doing and more robust. :-) Again, thank you for your work on this! |
||
|
|
||
| // Long 1 | ||
| buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array()); | ||
| buf.write(numHashes_); | ||
| wseg.set(JAVA_INT_UNALIGNED, offset, numBuckets_); | ||
| offset += 4; | ||
| wseg.set(JAVA_BYTE, offset++, numHashes_); | ||
| short hashSeed = Util.computeSeedHash(seed_); | ||
| buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array()); | ||
| wseg.set(JAVA_SHORT_UNALIGNED, offset, hashSeed); | ||
| offset += 2; | ||
| final byte NULL_8 = 0; | ||
| buf.write(NULL_8); | ||
| wseg.set(JAVA_BYTE, offset++, NULL_8); | ||
|
|
||
| if (isEmpty()) { | ||
| return; | ||
| return wseg.toArray(JAVA_BYTE); | ||
| } | ||
|
|
||
| final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array(); | ||
| buf.writeBytes(totWeightByte); | ||
| wseg.set(JAVA_LONG_UNALIGNED, offset, totalWeight_); | ||
| offset += 8; | ||
|
|
||
| for (long w: sketchArray_) { | ||
| buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array()); | ||
| wseg.set(JAVA_LONG_UNALIGNED, offset, w); | ||
| offset += 8; | ||
| } | ||
|
|
||
| return wseg.toArray(JAVA_BYTE); | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -384,20 +441,22 @@ public void serialize(ByteArrayOutputStream buf) { | |
| * @return The deserialized CountMinSketch. | ||
| */ | ||
| public static CountMinSketch deserialize(final byte[] b, final long seed) { | ||
| ByteBuffer buf = ByteBuffer.allocate(b.length); | ||
| buf.put(b); | ||
| buf.flip(); | ||
|
|
||
| final byte preambleLongs = buf.get(); | ||
| final byte serialVersion = buf.get(); | ||
| final byte familyId = buf.get(); | ||
| final byte flagsByte = buf.get(); | ||
| final int NULL_32 = buf.getInt(); | ||
|
|
||
| final int numBuckets = buf.getInt(); | ||
| final byte numHashes = buf.get(); | ||
| final short seedHash = buf.getShort(); | ||
| final byte NULL_8 = buf.get(); | ||
| final MemorySegment buf = MemorySegment.ofArray(b); | ||
| long offset = 0; | ||
|
|
||
| final byte preambleLongs = buf.get(JAVA_BYTE, offset++); | ||
| final byte serialVersion = buf.get(JAVA_BYTE, offset++); | ||
| final byte familyId = buf.get(JAVA_BYTE, offset++); | ||
| final byte flagsByte = buf.get(JAVA_BYTE, offset++); | ||
| final int NULL_32 = buf.get(JAVA_INT_UNALIGNED, offset); | ||
| offset += 4; | ||
|
|
||
| final int numBuckets = buf.get(JAVA_INT_UNALIGNED, offset); | ||
| offset += 4; | ||
| final byte numHashes = buf.get(JAVA_BYTE, offset++); | ||
| final short seedHash = buf.get(JAVA_SHORT_UNALIGNED, offset); | ||
| offset += 2; | ||
| final byte NULL_8 = buf.get(JAVA_BYTE, offset++); | ||
|
|
||
|
freakyzoidberg marked this conversation as resolved.
|
||
| if (seedHash != Util.computeSeedHash(seed)) { | ||
| throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", " | ||
|
|
@@ -409,11 +468,13 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) { | |
| if (empty) { | ||
| return cms; | ||
| } | ||
| long w = buf.getLong(); | ||
| long w = buf.get(JAVA_LONG_UNALIGNED, offset); | ||
| offset += 8; | ||
| cms.totalWeight_ = w; | ||
|
|
||
| for (int i = 0; i < cms.sketchArray_.length; i++) { | ||
| cms.sketchArray_[i] = buf.getLong(); | ||
| cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED, offset); | ||
| offset += 8; | ||
| } | ||
|
|
||
| return cms; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.