diff --git a/icechunk/src/ops/mod.rs b/icechunk/src/ops/mod.rs
index 1a004eccb..9dcf3212f 100644
--- a/icechunk/src/ops/mod.rs
+++ b/icechunk/src/ops/mod.rs
@@ -1,6 +1,8 @@
-use std::{collections::HashSet, future::ready, sync::Arc};
+use std::{collections::HashSet, sync::Arc};
 
+use async_stream::try_stream;
 use futures::{Stream, StreamExt as _, TryStreamExt as _, stream};
+use tokio::pin;
 use tracing::instrument;
 
 use crate::{
@@ -21,8 +23,6 @@ pub async fn all_roots<'a>(
     extra_roots: &'a HashSet<SnapshotId>,
 ) -> RefResult<impl Stream<Item = RefResult<SnapshotId>> + 'a> {
     let all_refs = list_refs(storage, storage_settings).await?;
-    // TODO: this could be optimized by not following the ancestry of snapshots that we have
-    // already seen
     let roots = stream::iter(all_refs)
         .then(move |r| async move {
             r.fetch(storage, storage_settings).await.map(|ref_data| ref_data.snapshot)
@@ -39,21 +39,73 @@ pub async fn pointed_snapshots<'a>(
     asset_manager: Arc<AssetManager>,
     extra_roots: &'a HashSet<SnapshotId>,
 ) -> RepositoryResult<impl Stream<Item = RepositoryResult<SnapshotId>> + 'a> {
-    let roots = all_roots(storage, storage_settings, extra_roots)
-        .await?
-        .err_into::<RepositoryError>();
-    Ok(roots
-        .and_then(move |snap_id| {
+    let mut seen: HashSet<SnapshotId> = HashSet::new();
+    let res = try_stream! {
+        let roots = all_roots(storage, storage_settings, extra_roots)
+            .await?
+            .err_into::<RepositoryError>();
+        pin!(roots);
+
+        while let Some(pointed_snap_id) = roots.try_next().await? {
             let asset_manager = Arc::clone(&asset_manager.clone());
-            async move {
-                let snap = asset_manager.fetch_snapshot(&snap_id).await?;
-                let parents = Arc::clone(&asset_manager)
-                    .snapshot_ancestry(&snap.id())
-                    .await?
-                    .map_ok(|parent| parent.id)
-                    .err_into();
-                Ok(stream::once(ready(Ok(snap_id))).chain(parents))
+            if !seen.contains(&pointed_snap_id) {
+                let parents = asset_manager.snapshot_ancestry(&pointed_snap_id).await?;
+                for await parent in parents {
+                    let snap_id = parent?.id;
+                    if seen.insert(snap_id.clone()) {
+                        // it's a new snapshot
+                        yield snap_id
+                    } else {
+                        // as soon as we find a repeated snapshot
+                        // there is no point in continuing to retrieve
+                        // the rest of the ancestry, it must be already
+                        // retrieved from other ref
+                        break
+                    }
+                }
+            }
-            }
-        })
-        .try_flatten())
+        }
+    };
+    Ok(res)
 }
+
+#[cfg(test)]
+#[allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
+mod tests {
+    use futures::TryStreamExt as _;
+    use std::collections::{HashMap, HashSet};
+
+    use bytes::Bytes;
+
+    use crate::{
+        Repository, format::Path, new_in_memory_storage, ops::pointed_snapshots,
+    };
+
+    #[tokio::test]
+    async fn test_pointed_snapshots_duplicate() -> Result<(), Box<dyn std::error::Error>>
+    {
+        let storage = new_in_memory_storage().await?;
+        let repo = Repository::create(None, storage.clone(), HashMap::new()).await?;
+        let mut session = repo.writable_session("main").await?;
+        session.add_group(Path::root(), Bytes::new()).await?;
+        let snap = session.commit("commit", None).await?;
+        repo.create_tag("tag1", &snap).await?;
+        let mut session = repo.writable_session("main").await?;
+        session.add_group("/foo".try_into().unwrap(), Bytes::new()).await?;
+        let snap = session.commit("commit", None).await?;
+        repo.create_tag("tag2", &snap).await?;
+
+        let all_snaps = pointed_snapshots(
+            storage.as_ref(),
+            &storage.default_settings(),
+            repo.asset_manager().clone(),
+            &HashSet::new(),
+        )
+        .await?
+        .try_collect::<Vec<_>>()
+        .await?;
+
+        assert_eq!(all_snaps.len(), 3);
+        Ok(())
+    }
+}