Skip to content

Commit 5042984

Browse files
committed
shardtree: Add the ability to avoid pruning specific checkpoints.
1 parent 4d797cc commit 5042984

File tree

7 files changed

+163
-48
lines changed

7 files changed

+163
-48
lines changed

shardtree/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ and this project adheres to Rust's notion of
77

88
## Unreleased
99

10+
## Added
11+
* `shardtree::store::ShardStore::{ensure_retained, ensured_retained_count, should_retain}`
12+
13+
## Changed
14+
* `shardtree::store::ShardStore::with_checkpoints` no longer takes its `self`
15+
reference argument as mutable.
16+
1017
## [0.2.0] - 2023-11-07
1118

1219
## Added

shardtree/src/batch.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,12 +485,12 @@ mod tests {
485485
ShardTree<MemoryShardStore<String, usize>, 6, 3>,
486486
ShardTree<MemoryShardStore<String, usize>, 6, 3>,
487487
) {
488-
let max_checkpoints = 10;
488+
let min_checkpoints_to_retain = 10;
489489
let start = Position::from(0);
490490
let end = start + leaves.len() as u64;
491491

492492
// Construct a tree using `ShardTree::insert_tree`.
493-
let mut left = ShardTree::new(MemoryShardStore::empty(), max_checkpoints);
493+
let mut left = ShardTree::new(MemoryShardStore::empty(), min_checkpoints_to_retain);
494494
if let Some(BatchInsertionResult {
495495
subtree,
496496
checkpoints,
@@ -503,7 +503,7 @@ mod tests {
503503
}
504504

505505
// Construct a tree using `ShardTree::batch_insert`.
506-
let mut right = ShardTree::new(MemoryShardStore::empty(), max_checkpoints);
506+
let mut right = ShardTree::new(MemoryShardStore::empty(), min_checkpoints_to_retain);
507507
right.batch_insert(start, leaves.into_iter()).unwrap();
508508

509509
(left, right)

shardtree/src/lib.rs

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ mod legacy;
6666
pub struct ShardTree<S: ShardStore, const DEPTH: u8, const SHARD_HEIGHT: u8> {
6767
/// The vector of tree shards.
6868
store: S,
69-
/// The maximum number of checkpoints to retain before pruning.
70-
max_checkpoints: usize,
69+
/// The minumum number of checkpoints to retain when pruning.
70+
min_checkpoints_to_retain: usize,
7171
}
7272

7373
impl<
@@ -79,10 +79,10 @@ impl<
7979
> ShardTree<S, DEPTH, SHARD_HEIGHT>
8080
{
8181
/// Creates a new empty tree.
82-
pub fn new(store: S, max_checkpoints: usize) -> Self {
82+
pub fn new(store: S, min_checkpoints_to_retain: usize) -> Self {
8383
Self {
8484
store,
85-
max_checkpoints,
85+
min_checkpoints_to_retain,
8686
}
8787
}
8888

@@ -438,14 +438,19 @@ impl<
438438
.checkpoint_count()
439439
.map_err(ShardTreeError::Storage)?;
440440
trace!(
441-
"Tree has {} checkpoints, max is {}",
441+
"Tree has {} checkpoints, min to be retained is {}",
442442
checkpoint_count,
443-
self.max_checkpoints,
443+
self.min_checkpoints_to_retain,
444444
);
445-
if checkpoint_count > self.max_checkpoints {
445+
let retain_count = self.min_checkpoints_to_retain
446+
+ self
447+
.store
448+
.ensured_retained_count()
449+
.map_err(ShardTreeError::Storage)?;
450+
if checkpoint_count > retain_count {
446451
// Batch removals by subtree & create a list of the checkpoint identifiers that
447452
// will be removed from the checkpoints map.
448-
let remove_count = checkpoint_count - self.max_checkpoints;
453+
let remove_count = checkpoint_count - retain_count;
449454
let mut checkpoints_to_delete = vec![];
450455
let mut clear_positions: BTreeMap<Address, BTreeMap<Position, RetentionFlags>> =
451456
BTreeMap::new();
@@ -454,8 +459,10 @@ impl<
454459
// When removing is true, we are iterating through the range of
455460
// checkpoints being removed. When remove is false, we are
456461
// iterating through the range of checkpoints that are being
457-
// retained.
458-
let removing = checkpoints_to_delete.len() < remove_count;
462+
// retained, or skipping over a particular checkpoint that we
463+
// have been explicitly asked to retain.
464+
let removing = checkpoints_to_delete.len() < remove_count
465+
&& !self.store.should_retain(cid)?;
459466

460467
if removing {
461468
checkpoints_to_delete.push(cid.clone());
@@ -1177,9 +1184,9 @@ impl<
11771184
/// Make a marked leaf at a position eligible to be pruned.
11781185
///
11791186
/// If the checkpoint associated with the specified identifier does not exist because the
1180-
/// corresponding checkpoint would have been more than `max_checkpoints` deep, the removal is
1181-
/// recorded as of the first existing checkpoint and the associated leaves will be pruned when
1182-
/// that checkpoint is subsequently removed.
1187+
/// corresponding checkpoint would have been more than `min_checkpoints_to_retain` deep, the
1188+
/// removal is recorded as of the first existing checkpoint and the associated leaves will be
1189+
/// pruned when that checkpoint is subsequently removed.
11831190
///
11841191
/// Returns `Ok(true)` if a mark was successfully removed from the leaf at the specified
11851192
/// position, `Ok(false)` if the tree does not contain a leaf at the specified position or is
@@ -1253,7 +1260,7 @@ mod tests {
12531260
};
12541261

12551262
use crate::{
1256-
store::memory::MemoryShardStore,
1263+
store::{memory::MemoryShardStore, ShardStore},
12571264
testing::{
12581265
arb_char_str, arb_shardtree, check_shard_sizes, check_shardtree_insertion,
12591266
check_witness_with_pruned_subtrees,
@@ -1355,21 +1362,57 @@ mod tests {
13551362
),
13561363
Ok(()),
13571364
);
1365+
1366+
// Append a leaf we want to retain
1367+
assert_eq!(tree.append('e'.to_string(), Retention::Marked), Ok(()),);
1368+
1369+
// Now a few more leaves and then checkpoint
1370+
for c in 'f'..='i' {
1371+
tree.append(c.to_string(), Retention::Ephemeral).unwrap();
1372+
}
1373+
1374+
// Checkpoint the tree. We'll want to retain this checkpoint.
1375+
assert_eq!(tree.checkpoint(12), Ok(true));
1376+
tree.store.ensure_retained(12).unwrap();
1377+
1378+
// Simulate adding yet another block
1379+
for c in 'j'..='m' {
1380+
tree.append(c.to_string(), Retention::Ephemeral).unwrap();
1381+
}
1382+
1383+
assert_eq!(tree.checkpoint(13), Ok(true));
1384+
1385+
// Witness `e` as of checkpoint 12
1386+
let e_witness_12 = tree
1387+
.witness_at_checkpoint_id(Position::from(4), &12)
1388+
.unwrap();
1389+
1390+
// Now add some more checkpoints, which would ordinarily cause checkpoint 12
1391+
// to be pruned (but will not, because we explicitly retained it.)
1392+
for i in 14..24 {
1393+
assert_eq!(tree.checkpoint(i), Ok(true));
1394+
}
1395+
1396+
// Verify that we can still compute the same root
1397+
assert_matches!(
1398+
tree.witness_at_checkpoint_id(Position::from(4), &12),
1399+
Ok(w) if w == e_witness_12
1400+
);
13581401
}
13591402

13601403
// Combined tree tests
13611404
#[allow(clippy::type_complexity)]
13621405
fn new_combined_tree<H: Hashable + Ord + Clone + core::fmt::Debug>(
1363-
max_checkpoints: usize,
1406+
min_checkpoints_to_retain: usize,
13641407
) -> CombinedTree<
13651408
H,
13661409
usize,
13671410
CompleteTree<H, usize, 4>,
13681411
ShardTree<MemoryShardStore<H, usize>, 4, 3>,
13691412
> {
13701413
CombinedTree::new(
1371-
CompleteTree::new(max_checkpoints),
1372-
ShardTree::new(MemoryShardStore::empty(), max_checkpoints),
1414+
CompleteTree::new(min_checkpoints_to_retain),
1415+
ShardTree::new(MemoryShardStore::empty(), min_checkpoints_to_retain),
13731416
)
13741417
}
13751418

shardtree/src/prunable.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ bitflags! {
2323
/// [`MARKED`]: RetentionFlags::MARKED
2424
const EPHEMERAL = 0b00000000;
2525

26-
/// A leaf with `CHECKPOINT` retention can be pruned when there are more than `max_checkpoints`
27-
/// additional checkpoint leaves, if it is not also a marked leaf.
26+
/// A leaf with `CHECKPOINT` retention can be pruned when there are more than
27+
/// `min_checkpoints_to_retain` additional checkpoints, if it is not also a marked leaf.
2828
const CHECKPOINT = 0b00000001;
2929

3030
/// A leaf with `MARKED` retention can be pruned only as a consequence of an explicit deletion
@@ -34,10 +34,12 @@ bitflags! {
3434
}
3535

3636
impl RetentionFlags {
37+
/// Returns whether the [`RetentionFlags::CHECKPOINT`] flag is set.
3738
pub fn is_checkpoint(&self) -> bool {
3839
(*self & RetentionFlags::CHECKPOINT) == RetentionFlags::CHECKPOINT
3940
}
4041

42+
/// Returns whether the [`RetentionFlags::MARKED`] flag is set.
4143
pub fn is_marked(&self) -> bool {
4244
(*self & RetentionFlags::MARKED) == RetentionFlags::MARKED
4345
}

shardtree/src/store.rs

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,21 @@ pub trait ShardStore {
9393
checkpoint: Checkpoint,
9494
) -> Result<(), Self::Error>;
9595

96+
/// Records the provided checkpoint ID as corresponding to a checkpoint to be retained across
97+
/// pruning operations.
98+
///
99+
/// Implementations of this method must add the provided checkpoint ID to the set of
100+
/// checkpoints to be retained even if no such checkpoint currently exists in the backing
101+
/// store.
102+
fn ensure_retained(&mut self, checkpoint_id: Self::CheckpointId) -> Result<(), Self::Error>;
103+
104+
/// Returns the number of checkpoints explicitly retained using [`ensure_retained`].
105+
fn ensured_retained_count(&self) -> Result<usize, Self::Error>;
106+
107+
/// Returns the set of identifiers for checkpoints that should be exempt from pruning
108+
/// operations.
109+
fn should_retain(&self, cid: &Self::CheckpointId) -> Result<bool, Self::Error>;
110+
96111
/// Returns the number of checkpoints maintained by the data store
97112
fn checkpoint_count(&self) -> Result<usize, Self::Error>;
98113

@@ -112,7 +127,7 @@ pub trait ShardStore {
112127

113128
/// Iterates in checkpoint ID order over the first `limit` checkpoints, applying the
114129
/// given callback to each.
115-
fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
130+
fn with_checkpoints<F>(&self, limit: usize, callback: F) -> Result<(), Self::Error>
116131
where
117132
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>;
118133

@@ -165,6 +180,10 @@ impl<S: ShardStore> ShardStore for &mut S {
165180
S::get_shard_roots(*self)
166181
}
167182

183+
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
184+
S::truncate(*self, from)
185+
}
186+
168187
fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
169188
S::get_cap(*self)
170189
}
@@ -173,10 +192,6 @@ impl<S: ShardStore> ShardStore for &mut S {
173192
S::put_cap(*self, cap)
174193
}
175194

176-
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
177-
S::truncate(*self, from)
178-
}
179-
180195
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
181196
S::min_checkpoint_id(self)
182197
}
@@ -193,6 +208,18 @@ impl<S: ShardStore> ShardStore for &mut S {
193208
S::add_checkpoint(self, checkpoint_id, checkpoint)
194209
}
195210

211+
fn ensure_retained(&mut self, checkpoint_id: Self::CheckpointId) -> Result<(), Self::Error> {
212+
S::ensure_retained(self, checkpoint_id)
213+
}
214+
215+
fn ensured_retained_count(&self) -> Result<usize, Self::Error> {
216+
S::ensured_retained_count(self)
217+
}
218+
219+
fn should_retain(&self, cid: &Self::CheckpointId) -> Result<bool, Self::Error> {
220+
S::should_retain(self, cid)
221+
}
222+
196223
fn checkpoint_count(&self) -> Result<usize, Self::Error> {
197224
S::checkpoint_count(self)
198225
}
@@ -211,11 +238,11 @@ impl<S: ShardStore> ShardStore for &mut S {
211238
S::get_checkpoint(self, checkpoint_id)
212239
}
213240

214-
fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
241+
fn with_checkpoints<F>(&self, limit: usize, mut callback: F) -> Result<(), Self::Error>
215242
where
216243
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
217244
{
218-
S::with_checkpoints(self, limit, callback)
245+
S::with_checkpoints(self, limit, |cid, c| callback(cid, c))
219246
}
220247

221248
fn update_checkpoint_with<F>(

shardtree/src/store/caching.rs

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ where
3737
S::CheckpointId: Clone + Ord,
3838
{
3939
/// Loads a `CachingShardStore` from the given backend.
40-
pub fn load(mut backend: S) -> Result<Self, S::Error> {
40+
pub fn load(backend: S) -> Result<Self, S::Error> {
4141
let mut cache = MemoryShardStore::empty();
4242

4343
for shard_root in backend.get_shard_roots()? {
@@ -94,6 +94,9 @@ where
9494
},
9595
)
9696
.unwrap();
97+
for cid in self.cache.checkpoints_to_retain() {
98+
self.backend.ensure_retained(cid.clone())?;
99+
}
97100
for (checkpoint_id, checkpoint) in checkpoints {
98101
self.backend.add_checkpoint(checkpoint_id, checkpoint)?;
99102
}
@@ -144,6 +147,14 @@ where
144147
self.cache.put_cap(cap)
145148
}
146149

150+
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
151+
self.cache.min_checkpoint_id()
152+
}
153+
154+
fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
155+
self.cache.max_checkpoint_id()
156+
}
157+
147158
fn add_checkpoint(
148159
&mut self,
149160
checkpoint_id: Self::CheckpointId,
@@ -152,15 +163,20 @@ where
152163
self.cache.add_checkpoint(checkpoint_id, checkpoint)
153164
}
154165

155-
fn checkpoint_count(&self) -> Result<usize, Self::Error> {
156-
self.cache.checkpoint_count()
166+
fn ensure_retained(&mut self, checkpoint_id: Self::CheckpointId) -> Result<(), Self::Error> {
167+
self.cache.ensure_retained(checkpoint_id)
157168
}
158169

159-
fn get_checkpoint(
160-
&self,
161-
checkpoint_id: &Self::CheckpointId,
162-
) -> Result<Option<Checkpoint>, Self::Error> {
163-
self.cache.get_checkpoint(checkpoint_id)
170+
fn ensured_retained_count(&self) -> Result<usize, Self::Error> {
171+
self.cache.ensured_retained_count()
172+
}
173+
174+
fn should_retain(&self, cid: &Self::CheckpointId) -> Result<bool, Self::Error> {
175+
self.cache.should_retain(cid)
176+
}
177+
178+
fn checkpoint_count(&self) -> Result<usize, Self::Error> {
179+
self.cache.checkpoint_count()
164180
}
165181

166182
fn get_checkpoint_at_depth(
@@ -170,19 +186,19 @@ where
170186
self.cache.get_checkpoint_at_depth(checkpoint_depth)
171187
}
172188

173-
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
174-
self.cache.min_checkpoint_id()
175-
}
176-
177-
fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
178-
self.cache.max_checkpoint_id()
189+
fn get_checkpoint(
190+
&self,
191+
checkpoint_id: &Self::CheckpointId,
192+
) -> Result<Option<Checkpoint>, Self::Error> {
193+
self.cache.get_checkpoint(checkpoint_id)
179194
}
180195

181-
fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
196+
fn with_checkpoints<F>(&self, limit: usize, mut callback: F) -> Result<(), Self::Error>
182197
where
183198
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
184199
{
185-
self.cache.with_checkpoints(limit, callback)
200+
self.cache
201+
.with_checkpoints(limit, |cid, c| callback(cid, c))
186202
}
187203

188204
fn update_checkpoint_with<F>(
@@ -229,7 +245,7 @@ mod tests {
229245
};
230246

231247
fn check_equal(
232-
mut lhs: MemoryShardStore<String, u64>,
248+
lhs: MemoryShardStore<String, u64>,
233249
rhs: CachingShardStore<MemoryShardStore<String, u64>>,
234250
) {
235251
let rhs = rhs.flush().unwrap();

0 commit comments

Comments
 (0)