diff --git a/storage/src/index/benches/bench.rs b/storage/src/index/benches/bench.rs index 16d6810b6aa..effcbb3529c 100644 --- a/storage/src/index/benches/bench.rs +++ b/storage/src/index/benches/bench.rs @@ -9,6 +9,7 @@ mod hashmap_insert; mod hashmap_insert_fixed; mod hashmap_iteration; mod insert; +mod insert_and_prune; mod lookup; mod lookup_miss; @@ -51,6 +52,7 @@ criterion_main!( hashmap_insert_fixed::benches, hashmap_insert::benches, insert::benches, + insert_and_prune::benches, lookup::benches, lookup_miss::benches, ); diff --git a/storage/src/index/benches/insert_and_prune.rs b/storage/src/index/benches/insert_and_prune.rs new file mode 100644 index 00000000000..15c1895aefd --- /dev/null +++ b/storage/src/index/benches/insert_and_prune.rs @@ -0,0 +1,60 @@ +use super::DummyMetrics; +use commonware_cryptography::{Hasher, Sha256}; +use commonware_storage::{ + index::{unordered, Unordered}, + translator::FourCap, +}; +use criterion::{criterion_group, Criterion}; +use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; +use std::time::{Duration, Instant}; + +#[cfg(not(full_bench))] +const N_ITEMS: [usize; 2] = [10_000, 50_000]; +#[cfg(full_bench)] +const N_ITEMS: [usize; 4] = [10_000, 50_000, 100_000, 500_000]; + +fn bench_insert_and_prune(c: &mut Criterion) { + for items in N_ITEMS { + let mut rng = StdRng::seed_from_u64(0); + let mut kvs = Vec::with_capacity(items); + for i in 0..items { + kvs.push((Sha256::hash(&i.to_be_bytes()), i as u64)); + } + kvs.shuffle(&mut rng); + + c.bench_function(&format!("{}/items={items}", module_path!()), |b| { + let kvs_data = kvs.clone(); + b.iter_custom(move |iters| { + let mut total = Duration::ZERO; + for _ in 0..iters { + let mut index = unordered::Index::new(DummyMetrics, FourCap); + total += run_benchmark(&mut index, &kvs_data); + } + total + }); + }); + } +} + +fn run_benchmark>( + index: &mut I, + kvs: &[(::Digest, u64)], +) -> Duration { + // Seed the index with initial values. + for (k, v) in kvs { + index.insert(k, *v); + } + + // Overwrite every key using insert_and_prune: prune the old value, insert the new one. + let start = Instant::now(); + for (k, v) in kvs { + index.insert_and_prune(k, *v + 1, |old| *old == *v); + } + start.elapsed() +} + +criterion_group! { + name = benches; + config = Criterion::default().sample_size(10); + targets = bench_insert_and_prune +} diff --git a/storage/src/index/mod.rs b/storage/src/index/mod.rs index a9551545789..2af9ca7cce7 100644 --- a/storage/src/index/mod.rs +++ b/storage/src/index/mod.rs @@ -37,8 +37,8 @@ pub mod unordered; /// /// - Must call `next()` before `update()`, `insert()`, or `delete()` to establish a valid position. /// - Once `next()` returns `None`, only `insert()` can be called. -/// - Dropping the `Cursor` automatically restores the list structure by reattaching any detached -/// `next` nodes. +/// - The cursor mutates the linked list in place. If the sole element is deleted, dropping the +/// cursor removes the map entry. /// /// _If you don't need advanced functionality, just use `insert()`, `insert_and_prune()`, or /// `remove()` from [Unordered] instead._ @@ -53,8 +53,8 @@ pub trait Cursor: Send + Sync { /// If after `insert()`, the next active item is the item after the inserted item. If after /// `delete()`, the next active item is the item after the deleted item. /// - /// Handles transitions between phases and adjusts for deletions. Returns `None` when the list - /// is exhausted. It is safe to call `next()` even after it returns `None`. + /// Advances through cursor states and adjusts for deletions. Returns `None` when the list is + /// exhausted. It is safe to call `next()` even after it returns `None`. #[allow(clippy::should_implement_trait)] fn next(&mut self) -> Option<&Self::Value>; @@ -66,7 +66,7 @@ pub trait Cursor: Send + Sync { /// Updates the value at the current position in the iteration. /// - /// Panics if called before `next()` or after iteration is complete (`Status::Done` phase). + /// Panics if called before `next()` or after iteration is complete. fn update(&mut self, value: Self::Value); /// Removes anything in the cursor that satisfies the predicate. @@ -131,7 +131,7 @@ pub trait Unordered: Send + Sync { value: Self::Value, ) -> Option>; - /// Inserts a new value at the current position. + /// Inserts a new value for the translated key. fn insert(&mut self, key: &[u8], value: Self::Value); /// Insert a value at the given translated key, and prune any values that are no longer valid. @@ -145,7 +145,11 @@ pub trait Unordered: Send + Sync { ); /// Remove all values associated with a translated key that match `predicate`. - fn prune(&mut self, key: &[u8], predicate: impl Fn(&Self::Value) -> bool); + fn prune(&mut self, key: &[u8], predicate: impl Fn(&Self::Value) -> bool) { + if let Some(mut cursor) = self.get_mut(key) { + cursor.prune(&predicate); + } + } /// Remove all values associated with a translated key. fn remove(&mut self, key: &[u8]); @@ -239,15 +243,15 @@ mod tests { index.insert(key, 3); assert_eq!(index.keys(), 1); - // Check that the values are in the correct order - assert_eq!(index.get(key).copied().collect::>(), vec![1, 3, 2]); + // Check that the values are in the expected newest-first order. + assert_eq!(index.get(key).copied().collect::>(), vec![3, 2, 1]); // Ensure cursor terminates { let mut cursor = index.get_mut(key).unwrap(); - assert_eq!(*cursor.next().unwrap(), 1); assert_eq!(*cursor.next().unwrap(), 3); assert_eq!(*cursor.next().unwrap(), 2); + assert_eq!(*cursor.next().unwrap(), 1); assert!(cursor.next().is_none()); } @@ -255,7 +259,7 @@ mod tests { index.insert(key, 3); index.insert(key, 4); index.prune(key, |i| *i == 3); - assert_eq!(index.get(key).copied().collect::>(), vec![1, 4, 2]); + assert_eq!(index.get(key).copied().collect::>(), vec![4, 2, 1]); index.prune(key, |_| true); // Try removing all of a keys values. assert_eq!( @@ -477,16 +481,16 @@ mod tests { index.insert(b"ab", 2); index.insert(b"abc", 3); - assert_eq!(index.get(b"ab").copied().collect::>(), vec![2, 3]); - assert_eq!(index.get(b"abc").copied().collect::>(), vec![2, 3]); + assert_eq!(index.get(b"ab").copied().collect::>(), vec![3, 2]); + assert_eq!(index.get(b"abc").copied().collect::>(), vec![3, 2]); index.insert(b"ab", 4); - assert_eq!(index.get(b"ab").copied().collect::>(), vec![2, 4, 3]); + assert_eq!(index.get(b"ab").copied().collect::>(), vec![4, 3, 2]); assert_eq!(index.keys(), 2); assert_eq!(index.items(), 4); index.prune(b"ab", |v| *v == 4); - assert_eq!(index.get(b"ab").copied().collect::>(), vec![2, 3]); + assert_eq!(index.get(b"ab").copied().collect::>(), vec![3, 2]); assert_eq!(index.keys(), 2); assert_eq!(index.items(), 3); @@ -539,7 +543,7 @@ mod tests { index.insert(b"key", 3); assert_eq!( index.get(b"key").copied().collect::>(), - vec![1, 3, 2] + vec![3, 2, 1] ); } @@ -581,7 +585,7 @@ mod tests { index.insert(b"key", 2); index.insert(b"key", 3); index.prune(b"key", |v| *v == 2); - assert_eq!(index.get(b"key").copied().collect::>(), vec![1, 3]); + assert_eq!(index.get(b"key").copied().collect::>(), vec![3, 1]); index.prune(b"key", |v| *v == 1); assert_eq!(index.get(b"key").copied().collect::>(), vec![3]); } @@ -685,7 +689,7 @@ mod tests { } assert_eq!( index.get(b"key").copied().collect::>(), - vec![11, 13, 12] + vec![13, 12, 11] ); } @@ -729,18 +733,18 @@ mod tests { index.insert(b"key", 4); assert_eq!( index.get(b"key").copied().collect::>(), - vec![1, 4, 3, 2] + vec![4, 3, 2, 1] ); { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 1); assert_eq!(*cursor.next().unwrap(), 4); + assert_eq!(*cursor.next().unwrap(), 3); let _ = cursor.next().unwrap(); cursor.update(99); } assert_eq!( index.get(b"key").copied().collect::>(), - vec![1, 4, 99, 2] + vec![4, 3, 99, 1] ); } @@ -778,59 +782,59 @@ mod tests { } fn run_index_remove_through_iterator>(index: &mut I) { - index.insert(b"key", 1); - index.insert(b"key", 2); - index.insert(b"key", 3); - index.insert(b"key", 4); + index.insert(b"key", 10); + index.insert(b"key", 20); + index.insert(b"key", 30); + index.insert(b"key", 40); assert_eq!( index.get(b"key").copied().collect::>(), - vec![1, 4, 3, 2] + vec![40, 30, 20, 10] ); assert_eq!(index.pruned(), 0); { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 1); + assert_eq!(*cursor.next().unwrap(), 40); cursor.delete(); } assert_eq!(index.pruned(), 1); assert_eq!( index.get(b"key").copied().collect::>(), - vec![4, 3, 2] + vec![30, 20, 10] ); - index.insert(b"key", 1); + index.insert(b"key", 50); assert_eq!( index.get(b"key").copied().collect::>(), - vec![4, 1, 3, 2] + vec![50, 30, 20, 10] ); { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 4); - assert_eq!(*cursor.next().unwrap(), 1); - assert_eq!(*cursor.next().unwrap(), 3); + assert_eq!(*cursor.next().unwrap(), 50); + assert_eq!(*cursor.next().unwrap(), 30); + assert_eq!(*cursor.next().unwrap(), 20); cursor.delete(); } assert_eq!(index.pruned(), 2); assert_eq!( index.get(b"key").copied().collect::>(), - vec![4, 1, 2] + vec![50, 30, 10] ); - index.insert(b"key", 3); + index.insert(b"key", 60); assert_eq!( index.get(b"key").copied().collect::>(), - vec![4, 3, 1, 2] + vec![60, 50, 30, 10] ); { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 4); - assert_eq!(*cursor.next().unwrap(), 3); - assert_eq!(*cursor.next().unwrap(), 1); - assert_eq!(*cursor.next().unwrap(), 2); + assert_eq!(*cursor.next().unwrap(), 60); + assert_eq!(*cursor.next().unwrap(), 50); + assert_eq!(*cursor.next().unwrap(), 30); + assert_eq!(*cursor.next().unwrap(), 10); cursor.delete(); } assert_eq!(index.pruned(), 3); assert_eq!( index.get(b"key").copied().collect::>(), - vec![4, 3, 1] + vec![60, 50, 30] ); index.remove(b"key"); assert_eq!(index.keys(), 0); @@ -897,8 +901,8 @@ mod tests { } index.insert(b"key", 100); let mut iter = index.get(b"key"); - assert_eq!(*iter.next().unwrap(), 1); assert_eq!(*iter.next().unwrap(), 100); + assert_eq!(*iter.next().unwrap(), 1); assert_eq!(*iter.next().unwrap(), 42); assert_eq!(*iter.next().unwrap(), 3); assert!(iter.next().is_none()); @@ -987,14 +991,14 @@ mod tests { } { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 0); - cursor.delete(); assert_eq!(*cursor.next().unwrap(), 3); cursor.delete(); assert_eq!(*cursor.next().unwrap(), 2); cursor.delete(); assert_eq!(*cursor.next().unwrap(), 1); cursor.delete(); + assert_eq!(*cursor.next().unwrap(), 0); + cursor.delete(); assert_eq!(cursor.next(), None); cursor.insert(4); assert_eq!(cursor.next(), None); @@ -1121,6 +1125,50 @@ mod tests { }); } + fn run_index_insert_and_prune_vacant_pruned>(index: &mut I) { + index.insert_and_prune(b"key", 1u64, |_| true); + assert_eq!( + index.get(b"key").copied().collect::>(), + Vec::::new() + ); + assert_eq!(index.items(), 0); + assert_eq!(index.keys(), 0); + assert_eq!(index.pruned(), 0); + } + + #[test_traced] + fn test_hash_index_insert_and_prune_vacant_pruned() { + let runner = deterministic::Runner::default(); + runner.start(|context| async move { + let mut index = new_unordered(context); + run_index_insert_and_prune_vacant_pruned(&mut index); + }); + } + + #[test_traced] + fn test_ordered_index_insert_and_prune_vacant_pruned() { + let runner = deterministic::Runner::default(); + runner.start(|context| async move { + let mut index = new_ordered(context); + run_index_insert_and_prune_vacant_pruned(&mut index); + }); + } + + #[test_traced] + fn test_partitioned_index_insert_and_prune_vacant_pruned() { + let runner = deterministic::Runner::default(); + runner.start(|context| async move { + { + let mut index = new_partitioned_unordered(context.child("unordered")); + run_index_insert_and_prune_vacant_pruned(&mut index); + } + { + let mut index = new_partitioned_ordered(context.child("ordered")); + run_index_insert_and_prune_vacant_pruned(&mut index); + } + }); + } + fn run_index_insert_and_prune_replace_one>(index: &mut I) { index.insert(b"key", 1u64); index.insert_and_prune(b"key", 2u64, |v| *v == 1); @@ -1295,13 +1343,13 @@ mod tests { } { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 0); assert_eq!(*cursor.next().unwrap(), 3); - cursor.delete(); assert_eq!(*cursor.next().unwrap(), 2); cursor.delete(); + assert_eq!(*cursor.next().unwrap(), 1); + cursor.delete(); } - assert_eq!(index.get(b"key").copied().collect::>(), vec![0, 1]); + assert_eq!(index.get(b"key").copied().collect::>(), vec![3, 0]); } #[test_traced] @@ -1343,14 +1391,14 @@ mod tests { } { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 0); - cursor.delete(); assert_eq!(*cursor.next().unwrap(), 3); cursor.delete(); assert_eq!(*cursor.next().unwrap(), 2); cursor.delete(); assert_eq!(*cursor.next().unwrap(), 1); cursor.delete(); + assert_eq!(*cursor.next().unwrap(), 0); + cursor.delete(); assert_eq!(cursor.next(), None); } assert_eq!(index.keys(), 0); @@ -1608,8 +1656,8 @@ mod tests { index.insert(b"key", 123); index.insert(b"key", 456); let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 123); assert_eq!(*cursor.next().unwrap(), 456); + assert_eq!(*cursor.next().unwrap(), 123); cursor.insert(789); assert_eq!(cursor.next(), None); cursor.insert(999); @@ -1656,7 +1704,7 @@ mod tests { index.insert(b"key", 123); index.insert(b"key", 456); let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 123); + assert_eq!(*cursor.next().unwrap(), 456); cursor.delete(); cursor.delete(); } @@ -1686,8 +1734,8 @@ mod tests { index.insert(b"key", 2); { let mut cursor = index.get_mut(b"key").unwrap(); - assert_eq!(*cursor.next().unwrap(), 1); assert_eq!(*cursor.next().unwrap(), 2); + assert_eq!(*cursor.next().unwrap(), 1); cursor.delete(); assert!(cursor.next().is_none()); assert!(cursor.next().is_none()); @@ -1734,10 +1782,10 @@ mod tests { index.insert(b"key", 2); index.insert(b"key", 3); let mut cur = index.get_mut(b"key").unwrap(); - assert_eq!(*cur.next().unwrap(), 1); assert_eq!(*cur.next().unwrap(), 3); - cur.delete(); assert_eq!(*cur.next().unwrap(), 2); + cur.delete(); + assert_eq!(*cur.next().unwrap(), 1); assert!(cur.next().is_none()); assert!(cur.next().is_none()); } @@ -1766,14 +1814,14 @@ mod tests { index.insert(b"key", 3); { let mut cur = index.get_mut(b"key").unwrap(); - assert_eq!(*cur.next().unwrap(), 1); - cur.delete(); assert_eq!(*cur.next().unwrap(), 3); + cur.delete(); assert_eq!(*cur.next().unwrap(), 2); + assert_eq!(*cur.next().unwrap(), 1); assert!(cur.next().is_none()); assert!(cur.next().is_none()); } - assert_eq!(index.get(b"key").copied().collect::>(), vec![3, 2]); + assert_eq!(index.get(b"key").copied().collect::>(), vec![2, 1]); } #[test_traced] @@ -1800,21 +1848,21 @@ mod tests { index.insert(b"key", 3); assert_eq!( index.get(b"key").copied().collect::>(), - vec![1, 3, 2] + vec![3, 2, 1] ); { let mut cur = index.get_mut(b"key").unwrap(); - assert_eq!(*cur.next().unwrap(), 1); - cur.delete(); assert_eq!(*cur.next().unwrap(), 3); - cur.insert(4); + cur.delete(); assert_eq!(*cur.next().unwrap(), 2); + cur.insert(4); + assert_eq!(*cur.next().unwrap(), 1); assert!(cur.next().is_none()); assert!(cur.next().is_none()); } assert_eq!( index.get(b"key").copied().collect::>(), - vec![3, 4, 2] + vec![2, 4, 1] ); } @@ -1855,9 +1903,9 @@ mod tests { index.insert(b"key", 1); index.insert(b"key", 2); let mut cur = index.get_mut(b"key").unwrap(); - assert_eq!(*cur.next().unwrap(), 1); - cur.insert(99); assert_eq!(*cur.next().unwrap(), 2); + cur.insert(99); + assert_eq!(*cur.next().unwrap(), 1); assert!(cur.next().is_none()); } @@ -1898,7 +1946,7 @@ mod tests { index.insert(b"key", 10); index.insert(b"key", 20); let mut cur = index.get_mut(b"key").unwrap(); - assert_eq!(*cur.next().unwrap(), 10); + assert_eq!(*cur.next().unwrap(), 20); cur.insert(15); cur.delete(); } @@ -1943,8 +1991,8 @@ mod tests { index.insert(b"key", 10); index.insert(b"key", 20); let mut cur = index.get_mut(b"key").unwrap(); - assert_eq!(*cur.next().unwrap(), 10); assert_eq!(*cur.next().unwrap(), 20); + assert_eq!(*cur.next().unwrap(), 10); cur.delete(); cur.insert(15); } @@ -1989,7 +2037,7 @@ mod tests { index.insert(b"key", 10); index.insert(b"key", 20); let mut cur = index.get_mut(b"key").unwrap(); - assert_eq!(*cur.next().unwrap(), 10); + assert_eq!(*cur.next().unwrap(), 20); cur.insert(15); cur.insert(25); } @@ -2080,7 +2128,7 @@ mod tests { }); } - fn run_index_drop_mid_iteration_relinks>(index: &mut I) { + fn run_index_drop_mid_iteration_preserves_chain>(index: &mut I) { for i in 0..5 { index.insert(b"z", i); } @@ -2091,39 +2139,39 @@ mod tests { } assert_eq!( index.get(b"z").copied().collect::>(), - vec![0, 4, 3, 2, 1] + vec![4, 3, 2, 1, 0] ); } #[test_traced] - fn test_hash_index_drop_mid_iteration_relinks() { + fn test_hash_index_drop_mid_iteration_preserves_chain() { let runner = deterministic::Runner::default(); runner.start(|context| async move { let mut index = new_unordered(context); - run_index_drop_mid_iteration_relinks(&mut index); + run_index_drop_mid_iteration_preserves_chain(&mut index); }); } #[test_traced] - fn test_ordered_index_drop_mid_iteration_relinks() { + fn test_ordered_index_drop_mid_iteration_preserves_chain() { let runner = deterministic::Runner::default(); runner.start(|context| async move { let mut index = new_ordered(context); - run_index_drop_mid_iteration_relinks(&mut index); + run_index_drop_mid_iteration_preserves_chain(&mut index); }); } #[test_traced] - fn test_partitioned_index_drop_mid_iteration_relinks() { + fn test_partitioned_index_drop_mid_iteration_preserves_chain() { let runner = deterministic::Runner::default(); runner.start(|context| async move { { let mut index = new_partitioned_unordered(context.child("unordered")); - run_index_drop_mid_iteration_relinks(&mut index); + run_index_drop_mid_iteration_preserves_chain(&mut index); } { let mut index = new_partitioned_ordered(context.child("ordered")); - run_index_drop_mid_iteration_relinks(&mut index); + run_index_drop_mid_iteration_preserves_chain(&mut index); } }); } diff --git a/storage/src/index/ordered.rs b/storage/src/index/ordered.rs index f7c26618f6d..922b0fe7482 100644 --- a/storage/src/index/ordered.rs +++ b/storage/src/index/ordered.rs @@ -6,7 +6,7 @@ use crate::{ index::{ - storage::{Cursor as CursorImpl, ImmutableCursor, IndexEntry, Record}, + storage::{insert_front, Cursor as CursorImpl, ImmutableCursor, IndexEntry, Record}, Cursor as CursorTrait, Ordered, Unordered, }, translator::Translator, @@ -30,9 +30,6 @@ use std::{ impl IndexEntry for BTreeOccupiedEntry<'_, K, Record> { - fn get(&self) -> &V { - &self.get().value - } fn get_mut(&mut self) -> &mut Record { self.get_mut() } @@ -251,11 +248,9 @@ impl Unordered for Index { fn insert(&mut self, key: &[u8], value: V) { let k = self.translator.transform(key); match self.map.entry(k) { - BTreeEntry::Occupied(entry) => { - let mut cursor = - Cursor::<'_, T::Key, V>::new(entry, &self.keys, &self.items, &self.pruned); - cursor.next(); - cursor.insert(value); + BTreeEntry::Occupied(mut entry) => { + insert_front(entry.get_mut(), value); + self.items.inc(); } BTreeEntry::Vacant(entry) => { Self::create(&self.keys, &self.items, entry, value); @@ -267,11 +262,9 @@ impl Unordered for Index { let k = self.translator.transform(key); match self.map.entry(k) { BTreeEntry::Occupied(entry) => { - // Get entry + // Remove anything that is prunable. let mut cursor = Cursor::<'_, T::Key, V>::new(entry, &self.keys, &self.items, &self.pruned); - - // Remove anything that is prunable. cursor.prune(&predicate); // Add our new value (if not prunable). @@ -280,32 +273,29 @@ impl Unordered for Index { } } BTreeEntry::Vacant(entry) => { - Self::create(&self.keys, &self.items, entry, value); + // Create the entry only if the new value is not prunable. + if !predicate(&value) { + Self::create(&self.keys, &self.items, entry, value); + } } } } - fn prune(&mut self, key: &[u8], predicate: impl Fn(&V) -> bool) { + fn remove(&mut self, key: &[u8]) { let k = self.translator.transform(key); - match self.map.entry(k) { - BTreeEntry::Occupied(entry) => { - // Get cursor - let mut cursor = - Cursor::<'_, T::Key, V>::new(entry, &self.keys, &self.items, &self.pruned); - - // Remove anything that is prunable. - cursor.prune(&predicate); + if let Some(mut record) = self.map.remove(&k) { + // To ensure metrics are accurate, account for all conflicting values in the chain. + self.keys.dec(); + self.items.dec(); + self.pruned.inc(); + while let Some(next) = record.next.take() { + self.items.dec(); + self.pruned.inc(); + record = *next; } - BTreeEntry::Vacant(_) => {} } } - fn remove(&mut self, key: &[u8]) { - // To ensure metrics are accurate, we iterate over all conflicting values and remove them - // one-by-one (rather than just removing the entire entry). - self.prune(key, |_| true); - } - #[cfg(test)] fn keys(&self) -> usize { self.map.len() @@ -387,15 +377,15 @@ mod tests { // Next translated key to 0x0b is 1c. let (mut next, wrapped) = index.next_translated_key(&hex!("0x0b0102")).unwrap(); assert!(!wrapped); - assert_eq!(next.next().unwrap(), &21); assert_eq!(next.next().unwrap(), &22); + assert_eq!(next.next().unwrap(), &21); assert_eq!(next.next(), None); // Next translated key to 0x1b is 1c. let (mut next, wrapped) = index.next_translated_key(&hex!("0x1b010203")).unwrap(); assert!(!wrapped); - assert_eq!(next.next().unwrap(), &21); assert_eq!(next.next().unwrap(), &22); + assert_eq!(next.next().unwrap(), &21); assert_eq!(next.next(), None); // Next translated key to 0x2a is 2d. @@ -431,8 +421,8 @@ mod tests { // Previous translated key is 1c. let (mut prev, wrapped) = index.prev_translated_key(&hex!("0x1d0102")).unwrap(); assert!(!wrapped); - assert_eq!(prev.next().unwrap(), &21); assert_eq!(prev.next().unwrap(), &22); + assert_eq!(prev.next().unwrap(), &21); assert_eq!(prev.next(), None); // Previous translated key is 2d. diff --git a/storage/src/index/partitioned/ordered.rs b/storage/src/index/partitioned/ordered.rs index 3fcbaa1f67c..f89869b62da 100644 --- a/storage/src/index/partitioned/ordered.rs +++ b/storage/src/index/partitioned/ordered.rs @@ -315,10 +315,10 @@ mod tests { } let first_translated_key = index.first_translated_key().unwrap().next().unwrap(); - assert_eq!(*first_translated_key, 0); + assert_eq!(*first_translated_key, u64::MAX); let last_translated_key = index.last_translated_key().unwrap().next().unwrap(); - assert_eq!(*last_translated_key, (255u64 << 8) | 255); + assert_eq!(*last_translated_key, u64::MAX); let last = [255u8, 255u8]; let (mut iter, wrapped) = index.next_translated_key(&last).unwrap(); @@ -331,17 +331,17 @@ mod tests { if !(b1 == 255 && b2 == 255) { let (mut iter, _) = index.next_translated_key(&key).unwrap(); let next = *iter.next().unwrap(); - assert_eq!(next, ((b1 as u64) << 8 | b2 as u64) + 1); - let next = *iter.next().unwrap(); assert_eq!(next, u64::MAX); + let next = *iter.next().unwrap(); + assert_eq!(next, ((b1 as u64) << 8 | b2 as u64) + 1); assert!(iter.next().is_none()); } if !(b1 == 0 && b2 == 0) { let (mut iter, _) = index.prev_translated_key(&key).unwrap(); let prev = *iter.next().unwrap(); - assert_eq!(prev, ((b1 as u64) << 8 | b2 as u64) - 1); - let prev = *iter.next().unwrap(); assert_eq!(prev, u64::MAX); + let prev = *iter.next().unwrap(); + assert_eq!(prev, ((b1 as u64) << 8 | b2 as u64) - 1); assert!(iter.next().is_none()); } } @@ -380,15 +380,15 @@ mod tests { // Next translated key to 0x0b02 is 1c. let (mut iter, wrapped) = index.next_translated_key(&hex!("0x0b02F2")).unwrap(); assert!(!wrapped); - assert_eq!(iter.next(), Some(&21)); assert_eq!(iter.next(), Some(&22)); + assert_eq!(iter.next(), Some(&21)); assert_eq!(iter.next(), None); // Next translated key to 0x1b is 1c. let (mut iter, wrapped) = index.next_translated_key(&hex!("0x1b010203")).unwrap(); assert!(!wrapped); - assert_eq!(iter.next(), Some(&21)); assert_eq!(iter.next(), Some(&22)); + assert_eq!(iter.next(), Some(&21)); assert_eq!(iter.next(), None); // Next translated key to 0x2a is 2d. @@ -424,8 +424,8 @@ mod tests { // Previous translated key is 1c. let (mut iter, wrapped) = index.prev_translated_key(&hex!("0x1d0102")).unwrap(); assert!(!wrapped); - assert_eq!(iter.next(), Some(&21)); assert_eq!(iter.next(), Some(&22)); + assert_eq!(iter.next(), Some(&21)); assert_eq!(iter.next(), None); // Previous translated key is 2d. diff --git a/storage/src/index/storage.rs b/storage/src/index/storage.rs index 0868268eef4..f5c4a2f8305 100644 --- a/storage/src/index/storage.rs +++ b/storage/src/index/storage.rs @@ -2,6 +2,7 @@ use crate::index::Cursor as CursorTrait; use commonware_runtime::telemetry::metrics::{Counter, Gauge}; +use std::ptr::NonNull; /// Each key is mapped to a [Record] that contains a linked list of potential values for that key. /// @@ -17,8 +18,15 @@ pub(super) struct Record { pub(super) next: Option>, } +pub(super) fn insert_front(record: &mut Record, mut value: V) { + std::mem::swap(&mut record.value, &mut value); + record.next = Some(Box::new(Record { + value, + next: record.next.take(), + })); +} + pub(super) trait IndexEntry: Send + Sync { - fn get(&self) -> &V; fn get_mut(&mut self) -> &mut Record; fn remove(self); } @@ -27,50 +35,42 @@ pub(super) trait IndexEntry: Send + Sync { /// `delete()`. const MUST_CALL_NEXT: &str = "must call Cursor::next()"; -/// Panic message shown when `update()` is called after [Cursor] has returned `None` or after -/// `insert()` or `delete()` (but before `next()`). +/// Panic message shown when `update()` or `delete()` is called after [Cursor] has returned `None`. const NO_ACTIVE_ITEM: &str = "no active item in Cursor"; -/// Phases of the [Cursor] during iteration. #[derive(PartialEq, Eq)] -enum Phase { - /// Before iteration starts. - Initial, - - /// The current entry. - Entry, - /// Some item after the current entry. - Next(Box>), - - /// Iteration is done. +enum State { + /// Before first `next()` call, or immediately after `insert()`/`delete()`. + NeedNext, + /// `next()` returned a value; `update()`/`delete()` are valid. + Active, + /// `next()` returned `None`; only `insert()` is valid. Done, - /// The current entry has no valid item. - EntryDeleted, - - /// The current entry has been deleted and we've updated its value in-place - /// to be the value of the next record. - PostDeleteEntry, - /// The item has been deleted and we may be pointing to the next item. - PostDeleteNext(Option>>), - /// An item has been inserted. - PostInsert(Box>), + /// The sole element was deleted; the entry will be removed on Drop. + EntryRemoved, } -/// A cursor for [crate::index] types that can be instantiated with any [IndexEntry] implementation. +/// A cursor that traverses and mutates a linked list of [Record]s in place using raw pointers. +/// +/// Tracks `prev` (for relinking on delete) and `current` (last item returned by `next`). +/// The next element to visit is derived from `current.next` (or the entry head when +/// `current` is `None`), so no separate `upcoming` pointer is needed. +/// +/// Invariants: +/// - `entry` owns the linked list and keeps it exclusively borrowed for the cursor's lifetime. +/// - `prev` and `current`, when present, point into that list. +/// - `prev` and `current` are created only from exclusive references through `record_ptr`. +/// - When both are present, `prev.next` owns `current`. +/// - After deleting a node, `current` is moved back to the previous live node or cleared. pub(super) struct Cursor<'a, V: Eq + Send + Sync, E: IndexEntry> { - // The current phase of the cursor. - phase: Phase, - - // The current entry. + // The occupied index entry that owns the linked list while the cursor exists. entry: Option, - - // The head of the linked list of previously visited records. - past: Option>>, - // The tail of the linked list of previously visited records. - past_tail: Option<*mut Record>, - // Whether we've pushed a record with a populated `next` field to `past` (invalidates - // `past_tail`). - past_pushed_list: bool, + // The live record immediately before `current`, used to relink on non-head deletes. + prev: Option>>, + // The last record returned by `next()`. + current: Option>>, + // The current position/state of the cursor. + state: State, // Metrics. keys: &'a Gauge, @@ -79,8 +79,7 @@ pub(super) struct Cursor<'a, V: Eq + Send + Sync, E: IndexEntry> { } impl<'a, V: Eq + Send + Sync, E: IndexEntry> Cursor<'a, V, E> { - /// Creates a new [Cursor] from a mutable record reference, detaching its `next` chain for - /// iteration. + /// Creates a new [Cursor] from an occupied index entry. pub(super) const fn new( entry: E, keys: &'a Gauge, @@ -88,190 +87,149 @@ impl<'a, V: Eq + Send + Sync, E: IndexEntry> Cursor<'a, V, E> { pruned: &'a Counter, ) -> Self { Self { - phase: Phase::Initial, - entry: Some(entry), - - past: None, - past_tail: None, - past_pushed_list: false, - + prev: None, + current: None, + state: State::NeedNext, keys, items, pruned, } } - /// Pushes a [Record] to the end of `past`. - /// - /// If the record has a `next`, this function cannot be called again. - pub(super) fn past_push(&mut self, next: Box>) { - // Ensure we only push a list once (`past_tail` becomes stale). - assert!(!self.past_pushed_list); - self.past_pushed_list = next.next.is_some(); - - // Add `next` to the tail of `past`. - if let Some(past_tail) = self.past_tail { - // SAFETY: `past_tail` is always either `None` or points to a valid `Record` - // within the `self.past` linked list. We only enter this branch when `past_tail` - // is `Some`, meaning it was previously set to point to an owned node. The - // assertion verifies the invariant that `past_tail.next` is `None` before we - // append to it. - unsafe { - assert!((*past_tail).next.is_none()); - (*past_tail).next = Some(next); - let tail_next = (*past_tail).next.as_mut().unwrap(); - self.past_tail = Some(&mut **tail_next as *mut Record); - } - } else { - self.past = Some(next); - self.past_tail = self.past.as_mut().map(|b| &mut **b as *mut Record); - } + fn record_ptr(record: &mut Record) -> NonNull> { + NonNull::from(record) } - /// If we are in a phase where we could return a value, return it. - pub(super) fn value(&self) -> Option<&V> { - match &self.phase { - Phase::Initial => unreachable!(), - Phase::Entry => self.entry.as_ref().map(|e| e.get()), - Phase::Next(current) => Some(¤t.value), - Phase::Done | Phase::EntryDeleted => None, - Phase::PostDeleteEntry | Phase::PostDeleteNext(_) | Phase::PostInsert(_) => { - unreachable!() - } - } + const fn record_mut(&mut self, mut ptr: NonNull>) -> &mut Record { + // SAFETY: `ptr` was created by `record_ptr` from a record owned by `entry`, which is + // exclusively borrowed through this cursor. Cursor state clears or rewinds pointers before + // an owner is dropped. + unsafe { ptr.as_mut() } } } impl> CursorTrait for Cursor<'_, V, E> { type Value = V; - fn update(&mut self, v: V) { - match &mut self.phase { - Phase::Initial => unreachable!("{MUST_CALL_NEXT}"), - Phase::Entry => { - self.entry.as_mut().unwrap().get_mut().value = v; - } - Phase::Next(next) => { - next.value = v; - } - Phase::Done - | Phase::EntryDeleted - | Phase::PostDeleteEntry - | Phase::PostDeleteNext(_) - | Phase::PostInsert(_) => unreachable!("{NO_ACTIVE_ITEM}"), + fn next(&mut self) -> Option<&V> { + match self.state { + State::Done | State::EntryRemoved => return None, + State::NeedNext | State::Active => {} } - } - fn next(&mut self) -> Option<&V> { - match std::mem::replace(&mut self.phase, Phase::Done) { - Phase::Initial | Phase::PostDeleteEntry => { - // We must start with some entry, so this will always be some non-None value. - self.phase = Phase::Entry; - } - Phase::Entry => { - // If there is a record after, we set it to be the current record. - if let Some(next) = self.entry.as_mut().unwrap().get_mut().next.take() { - self.phase = Phase::Next(next); + // Derive the next record from `current.next` or the entry head. + let next_ptr = if let Some(current) = self.current { + match self.record_mut(current).next.as_deref_mut() { + Some(next) => Self::record_ptr(next), + None => { + self.state = State::Done; + return None; } } - Phase::Next(mut current) | Phase::PostInsert(mut current) => { - // Take the next record and push the current one to the past list. - let next = current.next.take(); - self.past_push(current); + } else { + Self::record_ptr(self.entry.as_mut().unwrap().get_mut()) + }; - // Set the next record to be the current record. - if let Some(next) = next { - self.phase = Phase::Next(next); - } - } - Phase::Done => {} - Phase::EntryDeleted => { - self.phase = Phase::EntryDeleted; - } - Phase::PostDeleteNext(current) => { - // If the stale value is some, we set it to be the current record. - if let Some(current) = current { - self.phase = Phase::Next(current); - } - } + self.prev = self.current; + self.current = Some(next_ptr); + self.state = State::Active; + Some(&self.record_mut(next_ptr).value) + } + + fn update(&mut self, v: V) { + match self.state { + State::NeedNext => panic!("{MUST_CALL_NEXT}"), + State::Done | State::EntryRemoved => panic!("{NO_ACTIVE_ITEM}"), + State::Active => {} } - self.value() + assert!(self.current.is_some(), "Active state requires current"); + let current = self.current.unwrap(); + self.record_mut(current).value = v; } fn insert(&mut self, v: V) { - self.items.inc(); - match std::mem::replace(&mut self.phase, Phase::Done) { - Phase::Initial => unreachable!("{MUST_CALL_NEXT}"), - Phase::Entry => { - // Create a new record that points to entry's next. - let new = Box::new(Record { - value: v, - next: self.entry.as_mut().unwrap().get_mut().next.take(), - }); - - // Set the phase to the new record. - self.phase = Phase::PostInsert(new); + match self.state { + State::NeedNext => panic!("{MUST_CALL_NEXT}"), + State::Active => { + self.items.inc(); + assert!(self.current.is_some(), "Active state requires current"); + let current = self.current.unwrap(); + let inserted = { + let current_record = self.record_mut(current); + let new = Box::new(Record { + value: v, + next: current_record.next.take(), + }); + current_record.next = Some(new); + Self::record_ptr(current_record.next.as_deref_mut().unwrap()) + }; + // Advance past the inserted node so next() returns the element after it. + self.prev = self.current; + self.current = Some(inserted); + self.state = State::NeedNext; } - Phase::Next(mut current) => { - // Take next. - let next = current.next.take(); - - // Add current to the past list. - self.past_push(current); - - // Create a new record that points to the next's next. - let new = Box::new(Record { value: v, next }); - self.phase = Phase::PostInsert(new); + State::EntryRemoved => { + // Re-populate the entry that was emptied by delete. + self.items.inc(); + let entry_record = self.entry.as_mut().unwrap().get_mut(); + entry_record.value = v; + entry_record.next = None; + self.prev = None; + self.current = Some(Self::record_ptr(entry_record)); + self.state = State::Done; } - Phase::Done => { - // If we are done, we need to create a new record and - // immediately push it to the past list. - let new = Box::new(Record { - value: v, - next: None, - }); - self.past_push(new); - } - Phase::EntryDeleted => { - // If entry is deleted, we need to update it. - self.entry.as_mut().unwrap().get_mut().value = v; - - // We don't consider overwriting a deleted entry a collision. - } - Phase::PostDeleteEntry | Phase::PostDeleteNext(_) | Phase::PostInsert(_) => { - unreachable!("{MUST_CALL_NEXT}") + State::Done => { + self.items.inc(); + let last = self.current.or(self.prev); + assert!(last.is_some(), "Done state requires current or prev"); + let inserted = { + let last_record = self.record_mut(last.unwrap()); + last_record.next = Some(Box::new(Record { + value: v, + next: None, + })); + Self::record_ptr(last_record.next.as_deref_mut().unwrap()) + }; + self.prev = last; + self.current = Some(inserted); + self.state = State::Done; } } } fn delete(&mut self) { + match self.state { + State::NeedNext => panic!("{MUST_CALL_NEXT}"), + State::Done | State::EntryRemoved => panic!("{NO_ACTIVE_ITEM}"), + State::Active => {} + } self.pruned.inc(); self.items.dec(); - match std::mem::replace(&mut self.phase, Phase::Done) { - Phase::Initial => unreachable!("{MUST_CALL_NEXT}"), - Phase::Entry => { - // Attempt to overwrite the entry with the next value. - let entry = self.entry.as_mut().unwrap().get_mut(); - if let Some(next) = entry.next.take() { - entry.value = next.value; - entry.next = next.next; - self.phase = Phase::PostDeleteEntry; - return; - } - // If there is no next, we consider the entry deleted. - self.phase = Phase::EntryDeleted; - // We wait to update metrics until `drop()`. - } - Phase::Next(mut current) => { - // Drop current instead of pushing it to the past list. - let next = current.next.take(); - self.phase = Phase::PostDeleteNext(next); - } - Phase::Done | Phase::EntryDeleted => unreachable!("{NO_ACTIVE_ITEM}"), - Phase::PostDeleteEntry | Phase::PostDeleteNext(_) | Phase::PostInsert(_) => { - unreachable!("{MUST_CALL_NEXT}") + assert!(self.current.is_some(), "Active state requires current"); + let current = self.current.unwrap(); + + if let Some(prev) = self.prev { + // Deleting a non-head node: relink prev.next to current.next. + let next = self.record_mut(current).next.take(); + self.record_mut(prev).next = next; + self.current = self.prev; + self.prev = None; + self.state = State::NeedNext; + } else { + // Deleting the head node (the entry record itself). + let head = self.record_mut(current); + if let Some(next) = head.next.take() { + // Promote the next record into the head position. + head.value = next.value; + head.next = next.next; + self.current = None; + self.state = State::NeedNext; + } else { + // Sole element deleted. + self.current = None; + self.state = State::EntryRemoved; } } } @@ -286,69 +244,20 @@ impl> CursorTrait for Cursor<'_, V, E> { } } -// SAFETY: [Send] is safe because the raw pointer `past_tail` only ever points to heap memory -// owned by `self.past`. Since the pointer's referent is moved along with the [Cursor], no data -// races can occur. The `where` clause ensures all generic parameters are also [Send]. -unsafe impl<'a, V, E> Send for Cursor<'a, V, E> -where - V: Eq + Send + Sync, - E: IndexEntry, -{ -} - -// SAFETY: [Sync] is safe because the raw pointer `past_tail` only ever points to heap memory -// owned by `self.past`. Since `past_tail` is never dereferenced through shared references in -// a way that could cause data races, and the `where` clause ensures all generic parameters -// are also [Sync], it is safe to share references to [Cursor] across threads. -unsafe impl<'a, V, E> Sync for Cursor<'a, V, E> -where - V: Eq + Send + Sync, - E: IndexEntry, -{ -} +// SAFETY: `NonNull` is not `Send`, so this cannot be derived automatically. `prev` and `current` +// are only bookkeeping pointers into the linked list owned by `entry`. Moving the cursor to another +// thread also moves `entry`, keeping the list alive and exclusively borrowed by the cursor. +unsafe impl> Send for Cursor<'_, V, E> {} +// SAFETY: `NonNull` is not `Sync`, so this cannot be derived automatically. Sharing a cursor does +// not grant access to the records without `&mut self`, and `entry` keeps the list alive and +// exclusively borrowed for the cursor's lifetime. +unsafe impl> Sync for Cursor<'_, V, E> {} impl> Drop for Cursor<'_, V, E> { fn drop(&mut self) { - // Take the entry. - let mut entry = self.entry.take().unwrap(); - - // If there is a dangling next, we should add it to past. - match std::mem::replace(&mut self.phase, Phase::Done) { - Phase::Initial | Phase::Entry => { - // No action needed. - } - Phase::Next(next) => { - // If there is a next, we should add it to past. - self.past_push(next); - } - Phase::Done => { - // No action needed. - } - Phase::EntryDeleted => { - // If the entry is deleted, we should remove it. - self.keys.dec(); - entry.remove(); - return; - } - Phase::PostDeleteEntry => { - // No action needed. - } - Phase::PostDeleteNext(Some(next)) => { - // If there is a stale record, we should add it to past. - self.past_push(next); - } - Phase::PostDeleteNext(None) => { - // No action needed. - } - Phase::PostInsert(next) => { - // If there is a current record, we should add it to past. - self.past_push(next); - } - } - - // Attach the tip of past to the entry. - if let Some(past) = self.past.take() { - entry.get_mut().next = Some(past); + if self.state == State::EntryRemoved { + self.keys.dec(); + self.entry.take().unwrap().remove(); } } } diff --git a/storage/src/index/unordered.rs b/storage/src/index/unordered.rs index 79fc308164e..616133b7c0e 100644 --- a/storage/src/index/unordered.rs +++ b/storage/src/index/unordered.rs @@ -4,7 +4,7 @@ use crate::{ index::{ - storage::{Cursor as CursorImpl, ImmutableCursor, IndexEntry, Record}, + storage::{insert_front, Cursor as CursorImpl, ImmutableCursor, IndexEntry, Record}, Cursor as CursorTrait, Unordered, }, translator::Translator, @@ -25,9 +25,6 @@ const INITIAL_CAPACITY: usize = 256; /// Implementation of [IndexEntry] for [OccupiedEntry]. impl IndexEntry for OccupiedEntry<'_, K, Record> { - fn get(&self) -> &V { - &self.get().value - } fn get_mut(&mut self) -> &mut Record { self.get_mut() } @@ -171,11 +168,9 @@ impl Unordered for Index { fn insert(&mut self, key: &[u8], v: V) { let k = self.translator.transform(key); match self.map.entry(k) { - Entry::Occupied(entry) => { - let mut cursor = - Cursor::<'_, T::Key, V>::new(entry, &self.keys, &self.items, &self.pruned); - cursor.next(); - cursor.insert(v); + Entry::Occupied(mut entry) => { + insert_front(entry.get_mut(), v); + self.items.inc(); } Entry::Vacant(entry) => { Self::create(&self.keys, &self.items, entry, v); @@ -187,10 +182,9 @@ impl Unordered for Index { let k = self.translator.transform(key); match self.map.entry(k) { Entry::Occupied(entry) => { - // Get entry + // Remove anything that is prunable. let mut cursor = Cursor::<'_, T::Key, V>::new(entry, &self.keys, &self.items, &self.pruned); - cursor.prune(&predicate); // Add our new value (if not prunable). @@ -199,31 +193,29 @@ impl Unordered for Index { } } Entry::Vacant(entry) => { - Self::create(&self.keys, &self.items, entry, value); + // Create the entry only if the new value is not prunable. + if !predicate(&value) { + Self::create(&self.keys, &self.items, entry, value); + } } } } - fn prune(&mut self, key: &[u8], predicate: impl Fn(&V) -> bool) { + fn remove(&mut self, key: &[u8]) { let k = self.translator.transform(key); - match self.map.entry(k) { - Entry::Occupied(entry) => { - // Get cursor - let mut cursor = - Cursor::<'_, T::Key, V>::new(entry, &self.keys, &self.items, &self.pruned); - - cursor.prune(&predicate); + if let Some(mut record) = self.map.remove(&k) { + // To ensure metrics are accurate, account for all conflicting values in the chain. + self.keys.dec(); + self.items.dec(); + self.pruned.inc(); + while let Some(next) = record.next.take() { + self.items.dec(); + self.pruned.inc(); + record = *next; } - Entry::Vacant(_) => {} } } - fn remove(&mut self, key: &[u8]) { - // To ensure metrics are accurate, we iterate over all conflicting values and remove them - // one-by-one (rather than just removing the entire entry). - self.prune(key, |_| true); - } - #[cfg(test)] fn keys(&self) -> usize { self.map.len() diff --git a/storage/src/qmdb/any/batch.rs b/storage/src/qmdb/any/batch.rs index 451454e3376..58443226d21 100644 --- a/storage/src/qmdb/any/batch.rs +++ b/storage/src/qmdb/any/batch.rs @@ -314,25 +314,80 @@ impl<'a, K: Ord, F: Family, V> DiffMerge<'a, K, F, V> { cursors: streams.into_iter().map(|s| (s, 0)).collect(), } } + + fn peek_key(cursor: &(&'a DiffSlice, usize)) -> Option<&'a K> { + cursor.0.get(cursor.1).map(|(k, _)| k) + } } impl<'a, K: Ord, F: Family, V> Iterator for DiffMerge<'a, K, F, V> { type Item = (&'a K, &'a DiffEntry); fn next(&mut self) -> Option { + match self.cursors.len() { + 0 => None, + 1 => { + let (slice, pos) = &mut self.cursors[0]; + let (k, entry) = slice.get(*pos)?; + *pos += 1; + Some((k, entry)) + } + 2 => { + let ka = Self::peek_key(&self.cursors[0]); + let kb = Self::peek_key(&self.cursors[1]); + match (ka, kb) { + (Some(a), Some(b)) => match a.cmp(b) { + core::cmp::Ordering::Less => { + let (slice, pos) = &mut self.cursors[0]; + let item = &slice[*pos]; + *pos += 1; + Some((&item.0, &item.1)) + } + core::cmp::Ordering::Greater => { + let (slice, pos) = &mut self.cursors[1]; + let item = &slice[*pos]; + *pos += 1; + Some((&item.0, &item.1)) + } + core::cmp::Ordering::Equal => { + let (slice, pos) = &mut self.cursors[0]; + let item = &slice[*pos]; + *pos += 1; + self.cursors[1].1 += 1; + Some((&item.0, &item.1)) + } + }, + (Some(_), None) => { + let (slice, pos) = &mut self.cursors[0]; + let item = &slice[*pos]; + *pos += 1; + Some((&item.0, &item.1)) + } + (None, Some(_)) => { + let (slice, pos) = &mut self.cursors[1]; + let item = &slice[*pos]; + *pos += 1; + Some((&item.0, &item.1)) + } + (None, None) => None, + } + } + _ => self.next_general(), + } + } +} + +impl<'a, K: Ord, F: Family, V> DiffMerge<'a, K, F, V> { + fn next_general(&mut self) -> Option<(&'a K, &'a DiffEntry)> { let n = self.cursors.len(); let mut winner: Option = None; for level in 0..n { - let (slice, pos) = self.cursors[level]; - let Some((k, _)) = slice.get(pos) else { + let Some(k) = Self::peek_key(&self.cursors[level]) else { continue; }; let better = match winner { None => true, - Some(w) => { - let (ws, wpos) = self.cursors[w]; - *k < ws[wpos].0 - } + Some(w) => *k < *Self::peek_key(&self.cursors[w]).unwrap(), }; if better { winner = Some(level); @@ -340,10 +395,10 @@ impl<'a, K: Ord, F: Family, V> Iterator for DiffMerge<'a, K, F, V> { } let level = winner?; let (slice, pos) = self.cursors[level]; - for inner in 0..n { - let (s, p) = self.cursors[inner]; - if s.get(p).is_some_and(|(k, _)| *k == slice[pos].0) { - self.cursors[inner].1 += 1; + let winning_key = &slice[pos].0; + for cursor in &mut self.cursors { + if Self::peek_key(cursor).is_some_and(|k| k == winning_key) { + cursor.1 += 1; } } Some((&slice[pos].0, &slice[pos].1))