Skip to content

Commit eda9cf4

Browse files
authored
Safer reset_branch (#1259)
An optional third argument was added to `reset_branch`:

```python
def reset_branch(self, branch: str, snapshot_id: str, *, from_snapshot_id: str | None = None) -> None:
```

If `from_snapshot_id` is passed and the branch is not currently pointing to that snapshot id, the operation raises an exception and leaves the branch untouched.
1 parent 09329fb commit eda9cf4

File tree

7 files changed

+84
-23
lines changed

7 files changed

+84
-23
lines changed

docs/docs/icechunk-for-git-users.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,22 @@ In git, you can reset a branch to previous commit. Similarly, in Icechunk you ca
7171
repo.reset_branch("my-new-branch", "198273178639187")
7272
```
7373

74+
At this point, the tip of the branch is now the snapshot `198273178639187` and any changes made to the branch will be based on this snapshot. This also means the history of the branch is now the same as the ancestry of this snapshot.
75+
7476
!!! warning
7577
This is a destructive operation. It will overwrite the branch reference with the snapshot immediately. It can only be undone by resetting the branch again.
7678

77-
At this point, the tip of the branch is now the snapshot `198273178639187` and any changes made to the branch will be based on this snapshot. This also means the history of the branch is now same as the ancestry of this snapshot.
79+
To make `reset_branch` less dangerous, you can pass an optional keyword argument, `from_snapshot_id`, the snapshot id currently pointed to
80+
by the branch:
81+
82+
```python
83+
# reset the branch to the initial commit
84+
repo.reset_branch("my-new-branch", "1CECHNKREP0F1RSTCMT0", from_snapshot_id="198273178639187")
85+
```
86+
87+
If the branch is not currently pointing to snapshot `198273178639187`, the operation will be rejected
88+
with an exception. Using this approach, you can make sure to only reset a branch if no other commits
89+
were made to it since you last checked.
7890

7991
### Branch History
8092

icechunk-python/python/icechunk/_icechunk_python.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,8 +1395,12 @@ class PyRepository:
13951395
async def lookup_branch_async(self, branch: str) -> str: ...
13961396
def lookup_snapshot(self, snapshot_id: str) -> SnapshotInfo: ...
13971397
async def lookup_snapshot_async(self, snapshot_id: str) -> SnapshotInfo: ...
1398-
def reset_branch(self, branch: str, snapshot_id: str) -> None: ...
1399-
async def reset_branch_async(self, branch: str, snapshot_id: str) -> None: ...
1398+
def reset_branch(
1399+
self, branch: str, to_snapshot_id: str, from_snapshot_id: str | None
1400+
) -> None: ...
1401+
async def reset_branch_async(
1402+
self, branch: str, to_snapshot_id: str, from_snapshot_id: str | None
1403+
) -> None: ...
14001404
def delete_branch(self, branch: str) -> None: ...
14011405
async def delete_branch_async(self, branch: str) -> None: ...
14021406
def delete_tag(self, tag: str) -> None: ...

icechunk-python/python/icechunk/repository.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,9 @@ async def lookup_snapshot_async(self, snapshot_id: str) -> SnapshotInfo:
681681
"""
682682
return await self._repository.lookup_snapshot_async(snapshot_id)
683683

684-
def reset_branch(self, branch: str, snapshot_id: str) -> None:
684+
def reset_branch(
685+
self, branch: str, snapshot_id: str, *, from_snapshot_id: str | None = None
686+
) -> None:
685687
"""
686688
Reset a branch to a specific snapshot.
687689
@@ -694,14 +696,19 @@ def reset_branch(self, branch: str, snapshot_id: str) -> None:
694696
The branch to reset.
695697
snapshot_id : str
696698
The snapshot ID to reset the branch to.
699+
from_snapshot_id : str | None
700+
If passed, the reset will only be executed if the branch currently
701+
points to from_snapshot_id.
697702
698703
Returns
699704
-------
700705
None
701706
"""
702-
self._repository.reset_branch(branch, snapshot_id)
707+
self._repository.reset_branch(branch, snapshot_id, from_snapshot_id)
703708

704-
async def reset_branch_async(self, branch: str, snapshot_id: str) -> None:
709+
async def reset_branch_async(
710+
self, branch: str, snapshot_id: str, *, from_snapshot_id: str | None = None
711+
) -> None:
705712
"""
706713
Reset a branch to a specific snapshot (async version).
707714
@@ -714,12 +721,15 @@ async def reset_branch_async(self, branch: str, snapshot_id: str) -> None:
714721
The branch to reset.
715722
snapshot_id : str
716723
The snapshot ID to reset the branch to.
724+
from_snapshot_id : str | None
725+
If passed, the reset will only be executed if the branch currently
726+
points to from_snapshot_id.
717727
718728
Returns
719729
-------
720730
None
721731
"""
722-
await self._repository.reset_branch_async(branch, snapshot_id)
732+
await self._repository.reset_branch_async(branch, snapshot_id, from_snapshot_id)
723733

724734
def delete_branch(self, branch: str) -> None:
725735
"""

icechunk-python/src/repository.rs

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -973,21 +973,33 @@ impl PyRepository {
973973
&self,
974974
py: Python<'_>,
975975
branch_name: &str,
976-
snapshot_id: &str,
976+
to_snapshot_id: &str,
977+
from_snapshot_id: Option<&str>,
977978
) -> PyResult<()> {
978979
// This function calls block_on, so we need to allow other thread python to make progress
979980
py.allow_threads(move || {
980-
let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| {
981+
let to_snapshot_id = SnapshotId::try_from(to_snapshot_id).map_err(|_| {
981982
PyIcechunkStoreError::RepositoryError(
982-
RepositoryErrorKind::InvalidSnapshotId(snapshot_id.to_owned()).into(),
983+
RepositoryErrorKind::InvalidSnapshotId(to_snapshot_id.to_owned())
984+
.into(),
983985
)
984986
})?;
985987

988+
let from_snapshot_id = from_snapshot_id
989+
.map(|sid| {
990+
SnapshotId::try_from(sid).map_err(|_| {
991+
PyIcechunkStoreError::RepositoryError(
992+
RepositoryErrorKind::InvalidSnapshotId(sid.to_owned()).into(),
993+
)
994+
})
995+
})
996+
.transpose()?;
997+
986998
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
987999
self.0
9881000
.read()
9891001
.await
990-
.reset_branch(branch_name, &snapshot_id)
1002+
.reset_branch(branch_name, &to_snapshot_id, from_snapshot_id.as_ref())
9911003
.await
9921004
.map_err(PyIcechunkStoreError::RepositoryError)?;
9931005
Ok(())
@@ -999,20 +1011,31 @@ impl PyRepository {
9991011
&'py self,
10001012
py: Python<'py>,
10011013
branch_name: &str,
1002-
snapshot_id: &str,
1014+
to_snapshot_id: &str,
1015+
from_snapshot_id: Option<&str>,
10031016
) -> PyResult<Bound<'py, PyAny>> {
10041017
let repository = self.0.clone();
10051018
let branch_name = branch_name.to_owned();
1006-
let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| {
1019+
let to_snapshot_id = SnapshotId::try_from(to_snapshot_id).map_err(|_| {
10071020
PyIcechunkStoreError::RepositoryError(
1008-
RepositoryErrorKind::InvalidSnapshotId(snapshot_id.to_owned()).into(),
1021+
RepositoryErrorKind::InvalidSnapshotId(to_snapshot_id.to_owned()).into(),
10091022
)
10101023
})?;
10111024

1025+
let from_snapshot_id = from_snapshot_id
1026+
.map(|sid| {
1027+
SnapshotId::try_from(sid).map_err(|_| {
1028+
PyIcechunkStoreError::RepositoryError(
1029+
RepositoryErrorKind::InvalidSnapshotId(sid.to_owned()).into(),
1030+
)
1031+
})
1032+
})
1033+
.transpose()?;
1034+
10121035
pyo3_async_runtimes::tokio::future_into_py(py, async move {
10131036
let repository = repository.read().await;
10141037
repository
1015-
.reset_branch(&branch_name, &snapshot_id)
1038+
.reset_branch(&branch_name, &to_snapshot_id, from_snapshot_id.as_ref())
10161039
.await
10171040
.map_err(PyIcechunkStoreError::RepositoryError)?;
10181041
Ok(())

icechunk-python/tests/test_timetravel.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,21 @@ async def test_branch_reset() -> None:
232232

233233
group = zarr.open_group(store=store)
234234
group.create_group("b")
235-
session.commit("group b")
235+
last_commit = session.commit("group b")
236236

237237
keys = {k async for k in store.list()}
238238
assert "a/zarr.json" in keys
239239
assert "b/zarr.json" in keys
240240

241-
repo.reset_branch("main", prev_snapshot_id)
241+
with pytest.raises(ic.IcechunkError, match="branch update conflict"):
242+
repo.reset_branch(
243+
"main", prev_snapshot_id, from_snapshot_id="1CECHNKREP0F1RSTCMT0"
244+
)
245+
246+
assert last_commit == repo.lookup_branch("main")
247+
248+
repo.reset_branch("main", prev_snapshot_id, from_snapshot_id=last_commit)
249+
assert prev_snapshot_id == repo.lookup_branch("main")
242250

243251
session = repo.readonly_session("main")
244252
store = session.store

icechunk/src/repository.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,8 @@ impl Repository {
574574
pub async fn reset_branch(
575575
&self,
576576
branch: &str,
577-
snapshot_id: &SnapshotId,
577+
to_snapshot_id: &SnapshotId,
578+
from_snapshot_id: Option<&SnapshotId>,
578579
) -> RepositoryResult<()> {
579580
if !self.storage.can_write() {
580581
return Err(RepositoryErrorKind::ReadonlyStorage(
@@ -585,16 +586,19 @@ impl Repository {
585586
raise_if_invalid_snapshot_id(
586587
self.storage.as_ref(),
587588
&self.storage_settings,
588-
snapshot_id,
589+
to_snapshot_id,
589590
)
590591
.await?;
591-
let branch_tip = self.lookup_branch(branch).await?;
592+
let branch_tip = match from_snapshot_id {
593+
Some(snap) => snap,
594+
None => &self.lookup_branch(branch).await?,
595+
};
592596
update_branch(
593597
self.storage.as_ref(),
594598
&self.storage_settings,
595599
branch,
596-
snapshot_id.clone(),
597-
Some(&branch_tip),
600+
to_snapshot_id.clone(),
601+
Some(branch_tip),
598602
)
599603
.await
600604
.err_into()

icechunk/src/store.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2348,7 +2348,7 @@ mod tests {
23482348
assert!(store.exists("a/zarr.json").await?);
23492349
assert!(store.exists("b/zarr.json").await?);
23502350

2351-
repo.reset_branch("main", &prev_snap).await?;
2351+
repo.reset_branch("main", &prev_snap, None).await?;
23522352
let ds = Arc::new(RwLock::new(
23532353
repo.readonly_session(&VersionInfo::BranchTipRef("main".to_string())).await?,
23542354
));

0 commit comments

Comments
 (0)