diff --git a/raft/src/main/scala/zio/raft/HMap.scala b/raft/src/main/scala/zio/raft/HMap.scala index 74eda00f..ac08ba29 100644 --- a/raft/src/main/scala/zio/raft/HMap.scala +++ b/raft/src/main/scala/zio/raft/HMap.scala @@ -129,6 +129,12 @@ final case class HMap[M <: Tuple](private val m: TreeMap[Array[Byte], Any] = // Lower bound: [length][prefix] - all keys with this prefix start here val lowerBound = Array(prefixLength.toByte) ++ prefixBytes + val upperBound = computePrefixUpperBound(lowerBound) + + (lowerBound, upperBound) + + // Helper: computes the lexicographic upper bound for a prefix + private[raft] def computePrefixUpperBound(prefixBytes: Array[Byte]): Array[Byte] = // Upper bound: Increment prefix bytes with carry propagation // Start from rightmost byte, find first byte that isn't 0xFF, increment it, zero rest val upperPrefixBytes = prefixBytes.clone() @@ -137,24 +143,27 @@ final case class HMap[M <: Tuple](private val m: TreeMap[Array[Byte], Any] = while carry && i >= 0 do if upperPrefixBytes(i) != 0xff.toByte then - // Found a byte that is already 0xFF, will zero it and propagate carry (continue) - if upperPrefixBytes(i) == 0xff.toByte then - upperPrefixBytes(i) = 0.toByte - else - // Found a byte that is not 0xFF, increment it and stop carry - upperPrefixBytes(i) = (upperPrefixBytes(i) + 1).toByte - carry = false + // Found a byte that is not 0xFF, increment it and stop carry + upperPrefixBytes(i) = (upperPrefixBytes(i) + 1).toByte + carry = false + else + // Found a byte that is 0xFF, will zero it and propagate carry (continue) + upperPrefixBytes(i) = 0.toByte i -= 1 - val upperBound = - if carry then - // All bytes were 0xFF - use next length value with empty prefix - // This is lexicographically after all keys with current prefix length - Array((prefixLength + 1).toByte) - else - Array(prefixLength.toByte) ++ upperPrefixBytes + if carry then + // All bytes were 0xFF, append zero to the string + Array.fill(prefixBytes.length)(0xff.toByte) ++ Array(0.toByte) + else + // Truncate any trailing zeros from the upperPrefixBytes for minimal upper bound representation + var end = upperPrefixBytes.length + while end > 0 && upperPrefixBytes(end - 1) == 0.toByte do + end -= 1 - (lowerBound, upperBound) + if end != upperPrefixBytes.length then + upperPrefixBytes.slice(0, end) + else + upperPrefixBytes /** Retrieve a value for the given prefix and key. * @@ -342,6 +351,100 @@ final case class HMap[M <: Tuple](private val m: TreeMap[Array[Byte], Any] = (logicalKey, v.asInstanceOf[ValueAt[M, P]]) } + /** Returns an iterator over (key, value) pairs for all entries whose compound key starts with the specified prefix + * and partial key, but only within the scope of the first component of the compound key. + * + * This is intended for use with compound keys, where you want to fetch all entries grouped by the first part of the + * compound key. The user specifies the first part of the key, and a "zero" value (usually empty string or zero-like + * value) for the second component. The method returns all keys beginning with that compound key prefix, but only + * within the same first key. + * + * For example, for compound keys like (namespace, userId), you can fetch all keys for a given namespace: + * hmap.rangeByCompoundKeyPrefix["users"]((namespace, "")) This will return all user records within the `namespace`, + * regardless of the second component value. + * + * IMPORTANT: For this method to work correctly, the `KeyLike` implementation for the compound key type must encode + * only the leading component(s) in the byte array when the trailing ("zero") component is empty. That is, when + * encoding a partial/compound key like (namespace, ""), the encoder must omit any length prefix or bytes for the + * "zero"/empty part — the resulting byte array must end after the first, non-empty component. Do NOT emit an + * explicit "length = 0" for the empty tail component. + * + * For decoding, the `KeyLike` instance should interpret a missing (truncated) trailing component in the byte array + * as the "zero" value (such as empty string, 0, or Nil), i.e., treat the absence of those bytes as an empty value. + * + * This is required because the range calculation increments the bytes length for the entire prefix (including the + * first key component). If the zero-part is ever encoded explicitly, it will instead increment that rather than just + * the first component, breaking correct grouping/iteration. + * + * Example: KeyLike instance for (String, String) that omits the second component if empty: + * {{{ + * given KeyLike[(String, String)] with + * def asBytes(key: (String, String)): Array[Byte] = + * val (first, second) = key + * val firstBytes = first.getBytes(StandardCharsets.UTF_8) + * if second.isEmpty then + * // Only encode the first component, omit the second part entirely + * Array(firstBytes.length.toByte) ++ firstBytes + * else + * val secondBytes = second.getBytes(StandardCharsets.UTF_8) + * Array(firstBytes.length.toByte) ++ firstBytes ++ Array(secondBytes.length.toByte) ++ secondBytes + * + * def fromBytes(bytes: Array[Byte]): (String, String) = + * // Decode first component + * val len1 = bytes(0) & 0xff + * val first = new String(bytes.slice(1, 1 + len1), StandardCharsets.UTF_8) + * // If there are no more bytes, treat second as "" + * if bytes.length == 1 + len1 then (first, "") + * else + * val len2Pos = 1 + len1 + * val len2 = bytes(len2Pos) & 0xff + * val second = new String(bytes.slice(len2Pos + 1, len2Pos + 1 + len2), StandardCharsets.UTF_8) + * (first, second) + * }}} + * + * @tparam P + * The prefix (must be present in the schema) + * @param partial + * The partial (compound) key, where the first component is provided and the trailing component is "zero" (empty + * string, 0, Nil, etc. depending on how KeyLike[KeyAt[M, P]] is implemented, but crucially, should be OMITTED in + * byte encoding) + * @return + * Iterator of (KeyType, ValueType) pairs matching the compound prefix + * + * @example + * {{{ + * type Schema = ("users", (String, String), UserData) *: EmptyTuple + * val hmap = HMap.empty[Schema] + * .updated["users"](("region_1", "userA"), UserData(...)) + * .updated["users"](("region_1", "userB"), UserData(...)) + * .updated["users"](("region_2", "userC"), UserData(...)) + * + * // To select all users in "region_1", you must implement KeyLike so that + * // ("region_1", "") is encoded as just the bytes of "region_1" (no length/marker for the second field). + * hmap.rangeByCompoundKeyPrefix["users"](("region_1", "")) // Returns both userA and userB + * }}} + * + * NOTE: The name `rangeByCompoundKeyPrefix` is suggested as it clarifies the intent and scope. Alternative names: + * `rangeByPrefixKey`, `rangeByPrimaryKey`. + */ + def rangeByCompoundKeyPrefix[P <: String & Singleton: ValueOf](partial: KeyAt[M, P])(using + c: Contains[M, P], + kl: KeyLike[KeyAt[M, P]] + ): Iterator[(KeyAt[M, P], ValueAt[M, P])] = + val prefixBytes = valueOf[P].getBytes(StandardCharsets.UTF_8) + val prefixLength = Array(prefixBytes.length.toByte) + val prefixWithLength = prefixLength ++ prefixBytes + + val fromKey = prefixWithLength ++ kl.asBytes(partial) + + // Compute the lexicographic upper bound for the given partial key (only increment the prefix + first component of the compound key) + val untilKey = computePrefixUpperBound(fromKey) + + m.range(fromKey, untilKey).iterator.map { case (k, v) => + val logicalKey = extractKey(k) + (logicalKey, v.asInstanceOf[ValueAt[M, P]]) + } + /** Check if any entry in the specified prefix satisfies the predicate. * * Uses the underlying TreeMap's range.exists for efficient short-circuit evaluation. Stops as soon as it finds a @@ -403,7 +506,7 @@ object HMap: * * Private since it's only used internally by HMap and is explicitly referenced where needed. */ - private given byteArrayOrdering: Ordering[Array[Byte]] = + private[raft] given byteArrayOrdering: Ordering[Array[Byte]] = Ordering.comparatorToOrdering(java.util.Arrays.compareUnsigned(_, _)) /** Create an empty HMap with the given schema. diff --git a/raft/src/test/scala/zio/raft/HMapPrefixRangeSpec.scala b/raft/src/test/scala/zio/raft/HMapPrefixRangeSpec.scala index 0376309d..6f2330b1 100644 --- a/raft/src/test/scala/zio/raft/HMapPrefixRangeSpec.scala +++ b/raft/src/test/scala/zio/raft/HMapPrefixRangeSpec.scala @@ -185,5 +185,36 @@ object HMapPrefixRangeSpec extends ZIOSpecDefault: assertTrue(compareUnsigned(upper, lower) > 0) && assertTrue(compareUnsigned(lower, upper) < 0) } + + test("computePrefixUpperBound works") { + val prefixBytes = Array[Byte]('t'.toByte, 'e'.toByte, 's'.toByte, 't'.toByte) + val hmap = HMap.empty[TestSchema] + val upper = hmap.computePrefixUpperBound(prefixBytes) + + assertTrue(HMap.byteArrayOrdering.compare(upper, prefixBytes) == 1) && + assertTrue(upper.length == 4) && + assertTrue(upper.sameElements(Array[Byte]('t'.toByte, 'e'.toByte, 's'.toByte, 'u'.toByte))) + } + + test("computePrefixUpperBound for max string 0xff") { + val prefixBytes = Array[Byte](0xff.toByte, 0xff.toByte) + val hmap = HMap.empty[TestSchema] + val upper = hmap.computePrefixUpperBound(prefixBytes) + + assertTrue(HMap.byteArrayOrdering.compare(upper, prefixBytes) == 1) && + assertTrue(upper.length == 3) && + assertTrue(upper.sameElements(Array[Byte](0xff.toByte, 0xff.toByte, 0x00.toByte))) + } + + test("truncate trailing zeros from computePrefixUpperBound") { + val prefixBytes = Array[Byte](0x01.toByte, 0xFF.toByte) + val hmap = HMap.empty[TestSchema] + val upper = hmap.computePrefixUpperBound(prefixBytes) + + assertTrue(HMap.byteArrayOrdering.compare(upper, prefixBytes) == 1) && + assertTrue(upper.length == 1) && + assertTrue(upper.sameElements(Array[Byte](0x02.toByte))) + } + } end HMapPrefixRangeSpec diff --git a/raft/src/test/scala/zio/raft/HMapRangeByCompoundKeyPrefixSpec.scala b/raft/src/test/scala/zio/raft/HMapRangeByCompoundKeyPrefixSpec.scala new file mode 100644 index 00000000..4a7c8e20 --- /dev/null +++ b/raft/src/test/scala/zio/raft/HMapRangeByCompoundKeyPrefixSpec.scala @@ -0,0 +1,108 @@ +package zio.raft + +import zio.test.* +import zio.test.Assertion.* +import java.nio.charset.StandardCharsets + +object HMapRangeByCompoundKeyPrefixSpec extends ZIOSpecDefault: + + // Compound key encoding: + // bytes = lengthOfFirstComponent (1 byte) ++ firstComponentUtf8 ++ [lengthOfSecondComponent (1 byte) ++ secondComponentUtf8] + // (The second component is omitted if empty.) + // This ensures that all keys that share the same first component are in a contiguous + // lexicographic range [firstComponentLength ++ firstComponentUtf8, ...), which is what + // rangeByCompoundKeyPrefix relies on by computing the upper bound using carry propagation, + // special handling for 0xFF bytes, and trailing zero truncation, as implemented in computePrefixUpperBound. + given HMap.KeyLike[(String, String)] with + def asBytes(key: (String, String)): Array[Byte] = + val (first, second) = key + val firstBytes = first.getBytes(StandardCharsets.UTF_8) + if second.isEmpty then + Array(firstBytes.length.toByte) ++ firstBytes + else + val secondBytes = second.getBytes(StandardCharsets.UTF_8) + Array(firstBytes.length.toByte) ++ firstBytes ++ Array(secondBytes.length.toByte) ++ secondBytes + + def fromBytes(bytes: Array[Byte]): (String, String) = + val len1 = bytes(0) & 0xff + val first = new String(bytes.slice(1, 1 + len1), StandardCharsets.UTF_8) + if bytes.length == 1 + len1 then (first, "") + else + val len2Pos = 1 + len1 + val len2 = bytes(len2Pos) & 0xff + val second = new String(bytes.slice(len2Pos + 1, len2Pos + 1 + len2), StandardCharsets.UTF_8) + (first, second) + + type Schema = ("users", (String, String), Int) *: EmptyTuple + + def spec = suiteAll("HMap.rangeByCompoundKeyPrefix") { + + test("computePrefixUpperBound works for compound key prefix") { + val prefix = ("r1", "") + val key = ("r1", "a") + val hmap = HMap.empty[Schema] + val keyBytes = summon[HMap.KeyLike[(String, String)]].asBytes(key) + val upper = hmap.computePrefixUpperBound(summon[HMap.KeyLike[(String, String)]].asBytes(prefix)) + + + assertTrue(HMap.byteArrayOrdering.compare(upper, keyBytes) > 0) + } + + test("returns all entries that share the same first component only") { + val hmap = + HMap.empty[Schema] + .updated["users"](("r1", "a"), 1) + .updated["users"](("r1", "b"), 2) + .updated["users"](("r1", "c"), 3) + .updated["users"](("r2", "x"), 10) + .updated["users"](("r3", "y"), 20) + + val results = hmap.rangeByCompoundKeyPrefix["users"](("r1", "")).toList + val keys = results.map(_._1) + val values = results.map(_._2) + + assertTrue(results.length == 3) && + assertTrue(keys.toSet == Set(("r1", "a"), ("r1", "b"), ("r1", "c"))) && + assertTrue(values.toSet == Set(1, 2, 3)) + } + + test("includes empty-second-component key and excludes other first components") { + val hmap = + HMap.empty[Schema] + .updated["users"](("ns", ""), 0) + .updated["users"](("ns", "k1"), 1) + .updated["users"](("ns", "k2"), 2) + .updated["users"](("ns2", ""), 100) + .updated["users"](("ns2", "k3"), 101) + + val nsResults = hmap.rangeByCompoundKeyPrefix["users"](("ns", "")).toList + val nsKeys = nsResults.map(_._1).toSet + val nsValues = nsResults.map(_._2).toSet + + val ns2Results = hmap.rangeByCompoundKeyPrefix["users"](("ns2", "")).toList + val ns2Keys = ns2Results.map(_._1).toSet + val ns2Values = ns2Results.map(_._2).toSet + + assertTrue(nsKeys == Set(("ns", ""), ("ns", "k1"), ("ns", "k2"))) && + assertTrue(nsValues == Set(0, 1, 2)) && + assertTrue(ns2Keys == Set(("ns2", ""), ("ns2", "k3"))) && + assertTrue(ns2Values == Set(100, 101)) + } + + test("works with unicode in first component and multiple seconds") { + val first = "régiön-𝟙" // unicode characters + val hmap = + HMap.empty[Schema] + .updated["users"]((first, "α"), 5) + .updated["users"]((first, "β"), 6) + .updated["users"]((first, "γ"), 7) + .updated["users"](("other", "δ"), 8) + + val results = hmap.rangeByCompoundKeyPrefix["users"]((first, "")).toList + + assertTrue(results.length == 3) && + assertTrue(results.map(_._1).toSet == Set((first, "α"), (first, "β"), (first, "γ"))) && + assertTrue(results.map(_._2).toSet == Set(5, 6, 7)) + } + } +end HMapRangeByCompoundKeyPrefixSpec