Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 119 additions & 16 deletions raft/src/main/scala/zio/raft/HMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ final case class HMap[M <: Tuple](private val m: TreeMap[Array[Byte], Any] =
// Lower bound: [length][prefix] - all keys with this prefix start here
val lowerBound = Array(prefixLength.toByte) ++ prefixBytes

val upperBound = computePrefixUpperBound(lowerBound)

(lowerBound, upperBound)

// Helper: computes the lexicographic upper bound for a prefix
private[raft] def computePrefixUpperBound(prefixBytes: Array[Byte]): Array[Byte] =
// Upper bound: Increment prefix bytes with carry propagation
// Start from rightmost byte, find first byte that isn't 0xFF, increment it, zero rest
val upperPrefixBytes = prefixBytes.clone()
Expand All @@ -137,24 +143,27 @@ final case class HMap[M <: Tuple](private val m: TreeMap[Array[Byte], Any] =

while carry && i >= 0 do
if upperPrefixBytes(i) != 0xff.toByte then
// Found a byte that is already 0xFF, will zero it and propagate carry (continue)
if upperPrefixBytes(i) == 0xff.toByte then
upperPrefixBytes(i) = 0.toByte
else
// Found a byte that is not 0xFF, increment it and stop carry
upperPrefixBytes(i) = (upperPrefixBytes(i) + 1).toByte
carry = false
// Found a byte that is not 0xFF, increment it and stop carry
upperPrefixBytes(i) = (upperPrefixBytes(i) + 1).toByte
carry = false
else
// Found a byte that is 0xFF, will zero it and propagate carry (continue)
upperPrefixBytes(i) = 0.toByte
i -= 1

val upperBound =
if carry then
// All bytes were 0xFF - use next length value with empty prefix
// This is lexicographically after all keys with current prefix length
Array((prefixLength + 1).toByte)
else
Array(prefixLength.toByte) ++ upperPrefixBytes
if carry then
// All bytes were 0xFF, append zero to the string
Array.fill(prefixBytes.length)(0xff.toByte) ++ Array(0.toByte)
else
// Truncate any trailing zeros from the upperPrefixBytes for minimal upper bound representation
var end = upperPrefixBytes.length
while end > 0 && upperPrefixBytes(end - 1) == 0.toByte do
end -= 1

(lowerBound, upperBound)
if end != upperPrefixBytes.length then
upperPrefixBytes.slice(0, end)
else
upperPrefixBytes

/** Retrieve a value for the given prefix and key.
*
Expand Down Expand Up @@ -342,6 +351,100 @@ final case class HMap[M <: Tuple](private val m: TreeMap[Array[Byte], Any] =
(logicalKey, v.asInstanceOf[ValueAt[M, P]])
}

/** Iterates over every entry stored under prefix `P` whose encoded key starts with the encoding of `partial`.
  *
  * The intended use is grouping by the leading component of a compound key. The caller passes the leading
  * component together with a "zero" trailing component (for example an empty string), and receives all entries
  * that share that leading component:
  *
  * {{{
  * type Schema = ("users", (String, String), UserData) *: EmptyTuple
  * val hmap = HMap.empty[Schema]
  *   .updated["users"](("region_1", "userA"), UserData(...))
  *   .updated["users"](("region_1", "userB"), UserData(...))
  *   .updated["users"](("region_2", "userC"), UserData(...))
  *
  * hmap.rangeByCompoundKeyPrefix["users"](("region_1", "")) // yields userA and userB
  * }}}
  *
  * ==Requirement on `KeyLike`==
  * The scan covers the half-open byte range `[from, until)`, where `from` is the physical key built from the
  * prefix and `kl.asBytes(partial)`, and `until` is `computePrefixUpperBound(from)`. For that range to match
  * "all keys with the same leading component", the `KeyLike` encoder must emit ONLY the leading component when
  * the trailing component is the zero value. It must not append an explicit "length = 0" marker (or any other
  * bytes) for the empty tail; otherwise the upper bound is computed from that marker and the range no longer
  * spans the whole group. Symmetrically, the decoder should read a truncated byte array (no trailing component)
  * as the zero value.
  *
  * A `KeyLike[(String, String)]` that satisfies this contract:
  * {{{
  * given KeyLike[(String, String)] with
  *   def asBytes(key: (String, String)): Array[Byte] =
  *     val (first, second) = key
  *     val firstBytes = first.getBytes(StandardCharsets.UTF_8)
  *     if second.isEmpty then
  *       // Leading component only: nothing is written for the empty tail
  *       Array(firstBytes.length.toByte) ++ firstBytes
  *     else
  *       val secondBytes = second.getBytes(StandardCharsets.UTF_8)
  *       Array(firstBytes.length.toByte) ++ firstBytes ++ Array(secondBytes.length.toByte) ++ secondBytes
  *
  *   def fromBytes(bytes: Array[Byte]): (String, String) =
  *     val len1 = bytes(0) & 0xff
  *     val first = new String(bytes.slice(1, 1 + len1), StandardCharsets.UTF_8)
  *     // No bytes after the first component means the second component is ""
  *     if bytes.length == 1 + len1 then (first, "")
  *     else
  *       val len2Pos = 1 + len1
  *       val len2 = bytes(len2Pos) & 0xff
  *       val second = new String(bytes.slice(len2Pos + 1, len2Pos + 1 + len2), StandardCharsets.UTF_8)
  *       (first, second)
  * }}}
  *
  * @tparam P
  *   The prefix to scan (must be present in the schema)
  * @param partial
  *   The compound key carrying the leading component and a zero trailing component; the trailing component must be
  *   omitted from the byte encoding produced by `KeyLike[KeyAt[M, P]]`
  * @return
  *   Iterator of (key, value) pairs whose physical key lies in `[from, until)` as described above
  */
def rangeByCompoundKeyPrefix[P <: String & Singleton: ValueOf](partial: KeyAt[M, P])(using
    c: Contains[M, P],
    kl: KeyLike[KeyAt[M, P]]
): Iterator[(KeyAt[M, P], ValueAt[M, P])] =
  val namespaceBytes = valueOf[P].getBytes(StandardCharsets.UTF_8)

  // Physical key layout: [prefix length (1 byte)][prefix bytes][encoded logical key]
  val scanFrom = (namespaceBytes.length.toByte +: namespaceBytes) ++ kl.asBytes(partial)

  // Exclusive end of the scan: the lexicographic upper bound of everything that starts with `scanFrom`
  val scanUntil = computePrefixUpperBound(scanFrom)

  m.range(scanFrom, scanUntil).iterator.map { case (rawKey, rawValue) =>
    val logicalKey = extractKey(rawKey)
    (logicalKey, rawValue.asInstanceOf[ValueAt[M, P]])
  }

/** Check if any entry in the specified prefix satisfies the predicate.
*
* Uses the underlying TreeMap's range.exists for efficient short-circuit evaluation. Stops as soon as it finds a
Expand Down Expand Up @@ -403,7 +506,7 @@ object HMap:
*
* Package-private (`private[raft]`) rather than fully private: it is used internally by HMap and is also referenced explicitly by tests in the `zio.raft` package.
*/
private given byteArrayOrdering: Ordering[Array[Byte]] =
private[raft] given byteArrayOrdering: Ordering[Array[Byte]] =
Ordering.comparatorToOrdering(java.util.Arrays.compareUnsigned(_, _))

/** Create an empty HMap with the given schema.
Expand Down
31 changes: 31 additions & 0 deletions raft/src/test/scala/zio/raft/HMapPrefixRangeSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,36 @@ object HMapPrefixRangeSpec extends ZIOSpecDefault:
assertTrue(compareUnsigned(upper, lower) > 0) &&
assertTrue(compareUnsigned(lower, upper) < 0)
}

test("computePrefixUpperBound works") {
  // Incrementing the last byte of "test" ('t' -> 'u') gives the expected upper bound "tesu".
  val prefixBytes = Array[Byte]('t'.toByte, 'e'.toByte, 's'.toByte, 't'.toByte)
  val hmap = HMap.empty[TestSchema]
  val upper = hmap.computePrefixUpperBound(prefixBytes)

  // An Ordering only guarantees the SIGN of compare's result, so assert `> 0` instead of `== 1`.
  assertTrue(HMap.byteArrayOrdering.compare(upper, prefixBytes) > 0) &&
  assertTrue(upper.length == 4) &&
  assertTrue(upper.sameElements(Array[Byte]('t'.toByte, 'e'.toByte, 's'.toByte, 'u'.toByte)))
}

test("computePrefixUpperBound for max string 0xff") {
  // When every byte is 0xFF there is nothing left to increment; the current implementation
  // returns the prefix with a single 0x00 byte appended.
  // NOTE(review): under unsigned lexicographic ordering a key such as [0xff, 0xff, 0x01] sorts
  // above this bound even though it shares the prefix - confirm that callers never pass an
  // all-0xFF prefix before relying on this as a range end.
  val prefixBytes = Array[Byte](0xff.toByte, 0xff.toByte)
  val hmap = HMap.empty[TestSchema]
  val upper = hmap.computePrefixUpperBound(prefixBytes)

  // An Ordering only guarantees the SIGN of compare's result, so assert `> 0` instead of `== 1`.
  assertTrue(HMap.byteArrayOrdering.compare(upper, prefixBytes) > 0) &&
  assertTrue(upper.length == 3) &&
  assertTrue(upper.sameElements(Array[Byte](0xff.toByte, 0xff.toByte, 0x00.toByte)))
}

test("truncate trailing zeros from computePrefixUpperBound") {
  // [0x01, 0xFF] carries into the first byte, giving [0x02, 0x00]; the trailing zero produced by
  // the carry is dropped, so the expected bound is the single byte [0x02].
  val prefixBytes = Array[Byte](0x01.toByte, 0xFF.toByte)
  val hmap = HMap.empty[TestSchema]
  val upper = hmap.computePrefixUpperBound(prefixBytes)

  // An Ordering only guarantees the SIGN of compare's result, so assert `> 0` instead of `== 1`.
  assertTrue(HMap.byteArrayOrdering.compare(upper, prefixBytes) > 0) &&
  assertTrue(upper.length == 1) &&
  assertTrue(upper.sameElements(Array[Byte](0x02.toByte)))
}

}
end HMapPrefixRangeSpec
108 changes: 108 additions & 0 deletions raft/src/test/scala/zio/raft/HMapRangeByCompoundKeyPrefixSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package zio.raft

import zio.test.*
import zio.test.Assertion.*
import java.nio.charset.StandardCharsets

object HMapRangeByCompoundKeyPrefixSpec extends ZIOSpecDefault:

  // Test codec for two-part keys.
  //
  // Wire format:
  //   [length of first (1 byte)] ++ first as UTF-8 ++ [length of second (1 byte)] ++ second as UTF-8
  // where the whole second section is left out when the second component is empty.
  //
  // Because nothing is written for an empty second component, the encoding of (first, "") is a
  // byte-prefix of the encoding of every (first, x). All keys sharing a first component therefore
  // sit in one contiguous lexicographic range, which is what rangeByCompoundKeyPrefix scans using
  // the bound produced by computePrefixUpperBound.
  given HMap.KeyLike[(String, String)] with
    def asBytes(key: (String, String)): Array[Byte] =
      val headBytes = key._1.getBytes(StandardCharsets.UTF_8)
      val encodedHead = headBytes.length.toByte +: headBytes
      if key._2.isEmpty then encodedHead
      else
        val tailBytes = key._2.getBytes(StandardCharsets.UTF_8)
        encodedHead ++ (tailBytes.length.toByte +: tailBytes)

    def fromBytes(bytes: Array[Byte]): (String, String) =
      val headLen = bytes(0) & 0xff
      val tailLenAt = 1 + headLen
      val head = new String(bytes.slice(1, tailLenAt), StandardCharsets.UTF_8)
      // Nothing after the first component: the second component was omitted, i.e. it is ""
      if bytes.length == tailLenAt then (head, "")
      else
        val tailLen = bytes(tailLenAt) & 0xff
        val tailStart = tailLenAt + 1
        val tail = new String(bytes.slice(tailStart, tailStart + tailLen), StandardCharsets.UTF_8)
        (head, tail)

  type Schema = ("users", (String, String), Int) *: EmptyTuple

  def spec = suiteAll("HMap.rangeByCompoundKeyPrefix") {

    test("computePrefixUpperBound works for compound key prefix") {
      val codec = summon[HMap.KeyLike[(String, String)]]
      val fullKeyBytes = codec.asBytes(("r1", "a"))
      val bound = HMap.empty[Schema].computePrefixUpperBound(codec.asBytes(("r1", "")))

      // The bound derived from the partial key must sort after a full key in the same group
      assertTrue(HMap.byteArrayOrdering.compare(bound, fullKeyBytes) > 0)
    }

    test("returns all entries that share the same first component only") {
      val store =
        HMap.empty[Schema]
          .updated["users"](("r1", "a"), 1)
          .updated["users"](("r1", "b"), 2)
          .updated["users"](("r1", "c"), 3)
          .updated["users"](("r2", "x"), 10)
          .updated["users"](("r3", "y"), 20)

      val found = store.rangeByCompoundKeyPrefix["users"](("r1", "")).toList
      val foundKeys = found.map(_._1).toSet
      val foundValues = found.map(_._2).toSet

      assertTrue(found.size == 3) &&
      assertTrue(foundKeys == Set(("r1", "a"), ("r1", "b"), ("r1", "c"))) &&
      assertTrue(foundValues == Set(1, 2, 3))
    }

    test("includes empty-second-component key and excludes other first components") {
      val store =
        HMap.empty[Schema]
          .updated["users"](("ns", ""), 0)
          .updated["users"](("ns", "k1"), 1)
          .updated["users"](("ns", "k2"), 2)
          .updated["users"](("ns2", ""), 100)
          .updated["users"](("ns2", "k3"), 101)

      // "ns" and "ns2" differ in their length byte, so neither group may leak into the other
      val inNs = store.rangeByCompoundKeyPrefix["users"](("ns", "")).toList
      val inNs2 = store.rangeByCompoundKeyPrefix["users"](("ns2", "")).toList

      assertTrue(inNs.map(_._1).toSet == Set(("ns", ""), ("ns", "k1"), ("ns", "k2"))) &&
      assertTrue(inNs.map(_._2).toSet == Set(0, 1, 2)) &&
      assertTrue(inNs2.map(_._1).toSet == Set(("ns2", ""), ("ns2", "k3"))) &&
      assertTrue(inNs2.map(_._2).toSet == Set(100, 101))
    }

    test("works with unicode in first component and multiple seconds") {
      val region = "régiön-𝟙" // multi-byte UTF-8 characters in the first component
      val store =
        HMap.empty[Schema]
          .updated["users"]((region, "α"), 5)
          .updated["users"]((region, "β"), 6)
          .updated["users"]((region, "γ"), 7)
          .updated["users"](("other", "δ"), 8)

      val found = store.rangeByCompoundKeyPrefix["users"]((region, "")).toList
      val foundKeys = found.map(_._1).toSet
      val foundValues = found.map(_._2).toSet

      assertTrue(found.size == 3) &&
      assertTrue(foundKeys == Set((region, "α"), (region, "β"), (region, "γ"))) &&
      assertTrue(foundValues == Set(5, 6, 7))
    }
  }
end HMapRangeByCompoundKeyPrefixSpec