Skip to content

Commit 55ddaa4

Browse files
committed
Commit learns how to rebase in case of conflict
`Session.commit` accepts new arguments: ``` rebase_with: ConflictSolver | None = None, rebase_tries: int = 5 ``` It uses the `ConflictSolver` to call `Session.rebase` in a loop, up to `rebase_tries` times, and try to commit after each rebase.
1 parent 3e44f49 commit 55ddaa4

File tree

6 files changed

+208
-37
lines changed

6 files changed

+208
-37
lines changed

icechunk-python/examples/bank_accounts.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,15 +138,6 @@ def slow_transfer_task(repo: icechunk.Repository) -> WireResult:
138138
return res
139139

140140

141-
def rebase_loop(session: icechunk.Session, message: str) -> None:
142-
while True:
143-
try:
144-
session.commit(message)
145-
return
146-
except icechunk.ConflictError:
147-
session.rebase(icechunk.ConflictDetector())
148-
149-
150141
def rebase_transfer_task(repo: icechunk.Repository) -> WireResult:
151142
"""Safe and fast approach to concurrent transfers.
152143
@@ -169,7 +160,10 @@ def rebase_transfer_task(repo: icechunk.Repository) -> WireResult:
169160
return res
170161
if res == WireResult.DONE:
171162
try:
172-
rebase_loop(session, f"wired ${amount}: {from_account} -> {to_account}")
163+
session.commit(
164+
f"wired ${amount}: {from_account} -> {to_account}",
165+
rebase_with=icechunk.ConflictDetector(),
166+
)
173167
return WireResult.DONE
174168
except icechunk.RebaseFailedError:
175169
pass

icechunk-python/python/icechunk/_icechunk_python.pyi

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,13 @@ class PySession:
10661066
@property
10671067
def store(self) -> PyStore: ...
10681068
def merge(self, other: PySession) -> None: ...
1069-
def commit(self, message: str, metadata: dict[str, Any] | None = None) -> str: ...
1069+
def commit(
1070+
self,
1071+
message: str,
1072+
metadata: dict[str, Any] | None = None,
1073+
rebase_with: ConflictSolver | None = None,
1074+
rebase_tries: int = 5,
1075+
) -> str: ...
10701076
def rebase(self, solver: ConflictSolver) -> None: ...
10711077

10721078
class PyStore:

icechunk-python/python/icechunk/session.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,13 @@ def merge(self, other: Self) -> None:
250250
"""
251251
self._session.merge(other._session)
252252

253-
def commit(self, message: str, metadata: dict[str, Any] | None = None) -> str:
253+
def commit(
254+
self,
255+
message: str,
256+
metadata: dict[str, Any] | None = None,
257+
rebase_with: ConflictSolver | None = None,
258+
rebase_tries: int = 5,
259+
) -> str:
254260
"""
255261
Commit the changes in the session to the repository.
256262
@@ -264,6 +270,10 @@ def commit(self, message: str, metadata: dict[str, Any] | None = None) -> str:
264270
The message to write with the commit.
265271
metadata : dict[str, Any] | None, optional
266272
Additional metadata to store with the commit snapshot.
273+
rebase_with : ConflictSolver | None, optional
274+
If other session committed while the current session was writing, use Session.rebase with this solver.
275+
rebase_tries : int, optional
276+
If other session committed while the current session was writing, use Session.rebase up to this many times in a loop.
267277
268278
Returns
269279
-------
@@ -276,9 +286,13 @@ def commit(self, message: str, metadata: dict[str, Any] | None = None) -> str:
276286
If the session is out of date and a conflict occurs.
277287
"""
278288
try:
279-
return self._session.commit(message, metadata)
289+
return self._session.commit(
290+
message, metadata, rebase_with=rebase_with, rebase_tries=rebase_tries
291+
)
280292
except PyConflictError as e:
281293
raise ConflictError(e) from None
294+
except PyRebaseFailedError as e:
295+
raise RebaseFailedError(e) from None
282296

283297
def rebase(self, solver: ConflictSolver) -> None:
284298
"""

icechunk-python/src/session.rs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,23 +180,35 @@ impl PySession {
180180
})
181181
}
182182

183-
#[pyo3(signature = (message, metadata=None))]
183+
#[pyo3(signature = (message, metadata=None, rebase_with=None, rebase_tries=5))]
184184
pub fn commit(
185185
&self,
186186
py: Python<'_>,
187187
message: &str,
188188
metadata: Option<PySnapshotProperties>,
189+
rebase_with: Option<PyConflictSolver>,
190+
rebase_tries: Option<u16>,
189191
) -> PyResult<String> {
192+
let metadata = metadata.map(|m| m.into());
190193
// This is blocking function, we need to release the Gil
191194
py.allow_threads(move || {
192195
pyo3_async_runtimes::tokio::get_runtime().block_on(async {
193-
let snapshot_id = self
194-
.0
195-
.write()
196-
.await
197-
.commit(message, metadata.map(|m| m.into()))
198-
.await
199-
.map_err(PyIcechunkStoreError::SessionError)?;
196+
let mut session = self.0.write().await;
197+
let snapshot_id = if let Some(solver) = rebase_with {
198+
session
199+
.commit_rebasing(
200+
solver.as_ref(),
201+
rebase_tries.unwrap_or(5),
202+
message,
203+
metadata,
204+
|_| async {},
205+
|_| async {},
206+
)
207+
.await
208+
} else {
209+
session.commit(message, metadata).await
210+
}
211+
.map_err(PyIcechunkStoreError::SessionError)?;
200212
Ok(snapshot_id.to_string())
201213
})
202214
})

icechunk-python/tests/test_conflicts.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ def test_rebase_no_conflicts(repo: icechunk.Repository) -> None:
8383
array_b = cast(zarr.Array, root_b["foo/bar/some-array"])
8484
array_b.attrs["repo"] = 2
8585

86-
session_b.rebase(icechunk.ConflictDetector())
87-
session_b.commit("update array")
86+
session_b.commit("update array", rebase_with=icechunk.ConflictDetector())
8887

8988
session_c = repo.readonly_session(branch="main")
9089
store_c = session_c.store
@@ -94,13 +93,7 @@ def test_rebase_no_conflicts(repo: icechunk.Repository) -> None:
9493
assert array_c.attrs["repo"] == 2
9594

9695

97-
@pytest.mark.parametrize(
98-
"on_chunk_conflict",
99-
[icechunk.VersionSelection.UseOurs, icechunk.VersionSelection.UseTheirs],
100-
)
101-
def test_rebase_fails_on_user_atts_double_edit(
102-
repo: icechunk.Repository, on_chunk_conflict: icechunk.VersionSelection
103-
) -> None:
96+
def test_rebase_fails_on_user_atts_double_edit(repo: icechunk.Repository) -> None:
10497
session_a = repo.writable_session("main")
10598
session_b = repo.writable_session("main")
10699
store_a = session_a.store
@@ -120,7 +113,7 @@ def test_rebase_fails_on_user_atts_double_edit(
120113

121114
# Make sure it fails if the resolver is not set
122115
with pytest.raises(icechunk.RebaseFailedError):
123-
session_b.rebase(icechunk.BasicConflictSolver())
116+
session_b.commit("update array", rebase_with=icechunk.BasicConflictSolver())
124117

125118

126119
@pytest.mark.parametrize(
@@ -150,10 +143,11 @@ def test_rebase_chunks_with_ours(
150143
# Make sure it fails if the resolver is not set
151144
with pytest.raises(icechunk.RebaseFailedError):
152145
try:
153-
session_b.rebase(
154-
icechunk.BasicConflictSolver(
146+
session_b.commit(
147+
"update first column of array",
148+
rebase_with=icechunk.BasicConflictSolver(
155149
on_chunk_conflict=icechunk.VersionSelection.Fail
156-
)
150+
),
157151
)
158152
except icechunk.RebaseFailedError as e:
159153
assert e.conflicts[0].path == "/foo/bar/some-array"
@@ -184,8 +178,7 @@ def test_rebase_chunks_with_ours(
184178
on_chunk_conflict=on_chunk_conflict,
185179
)
186180

187-
session_b.rebase(solver)
188-
session_b.commit("after conflict")
181+
session_b.commit("after conflict", rebase_with=solver)
189182

190183
session_c = repo.readonly_session(branch="main")
191184
store_c = session_c.store

icechunk/src/session.rs

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,38 @@ impl Session {
877877
Ok(id)
878878
}
879879

880+
pub async fn commit_rebasing<F1, F2, Fut1, Fut2>(
881+
&mut self,
882+
solver: &dyn ConflictSolver,
883+
rebase_attempts: u16,
884+
message: &str,
885+
properties: Option<SnapshotProperties>,
886+
// We would prefer to make this argument optional, but passing None
887+
// for this argument is so hard. Callers should just pass noop closure like
888+
// |_| async {},
889+
before_rebase: F1,
890+
after_rebase: F2,
891+
) -> SessionResult<SnapshotId>
892+
where
893+
F1: Fn(u16) -> Fut1,
894+
F2: Fn(u16) -> Fut2,
895+
Fut1: Future<Output = ()>,
896+
Fut2: Future<Output = ()>,
897+
{
898+
for attempt in 0..rebase_attempts {
899+
match self.commit(message, properties.clone()).await {
900+
Ok(snap) => return Ok(snap),
901+
Err(SessionError { kind: SessionErrorKind::Conflict { .. }, .. }) => {
902+
before_rebase(attempt + 1).await;
903+
self.rebase(solver).await?;
904+
after_rebase(attempt + 1).await;
905+
}
906+
Err(other_err) => return Err(other_err),
907+
}
908+
}
909+
self.commit(message, properties).await
910+
}
911+
880912
/// Detect and optionally fix conflicts between the current [`ChangeSet`] (or session) and
881913
/// the tip of the branch.
882914
///
@@ -1754,7 +1786,11 @@ fn aggregate_extents<'a, T: std::fmt::Debug, E>(
17541786
#[cfg(test)]
17551787
#[allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
17561788
mod tests {
1757-
use std::{collections::HashMap, error::Error};
1789+
use std::{
1790+
collections::HashMap,
1791+
error::Error,
1792+
sync::atomic::{AtomicU16, Ordering},
1793+
};
17581794

17591795
use crate::{
17601796
ObjectStorage, Repository,
@@ -3752,6 +3788,122 @@ mod tests {
37523788
Ok(())
37533789
}
37543790

3791+
#[tokio_test]
3792+
/// Tests `commit_rebasing` retries the proper number of times when there are conflicts
3793+
async fn test_commit_rebasing_attempts() -> Result<(), Box<dyn Error>> {
3794+
let repo = Arc::new(create_memory_store_repository().await);
3795+
let mut session = repo.writable_session("main").await?;
3796+
session
3797+
.add_array("/array".try_into().unwrap(), basic_shape(), None, Bytes::new())
3798+
.await?;
3799+
session.commit("create array", None).await?;
3800+
3801+
// This is the main session we'll be trying to commit (and rebase)
3802+
let mut session = repo.writable_session("main").await?;
3803+
let path: Path = "/array".try_into().unwrap();
3804+
session
3805+
.set_chunk_ref(
3806+
path.clone(),
3807+
ChunkIndices(vec![1]),
3808+
Some(ChunkPayload::Inline("repo 1".into())),
3809+
)
3810+
.await?;
3811+
3812+
// we create an initial conflict for commit
3813+
let mut session2 = repo.writable_session("main").await.unwrap();
3814+
let path: Path = "/array".try_into().unwrap();
3815+
session2
3816+
.set_chunk_ref(
3817+
path.clone(),
3818+
ChunkIndices(vec![2]),
3819+
Some(ChunkPayload::Inline("repo 1".into())),
3820+
)
3821+
.await
3822+
.unwrap();
3823+
session2.commit("conflicting", None).await.unwrap();
3824+
3825+
let repo_ref = &repo;
3826+
let attempts = AtomicU16::new(0);
3827+
let attempts_ref = &attempts;
3828+
3829+
// after each rebase attempt we'll run this closure that creates a new conflict
3830+
// the result should be that it can never commit, failing after the indicated number of
3831+
// attempts
3832+
let conflicting = |attempt| async move {
3833+
attempts_ref.fetch_add(1, Ordering::SeqCst); //*attempts_ref = *attempts_ref + 1;;
3834+
assert_eq!(attempt, attempts_ref.load(Ordering::SeqCst));
3835+
3836+
let repo_c = Arc::clone(repo_ref);
3837+
let mut s = repo_c.writable_session("main").await.unwrap();
3838+
s.set_chunk_ref(
3839+
"/array".try_into().unwrap(),
3840+
ChunkIndices(vec![2]),
3841+
Some(ChunkPayload::Inline("repo 1".into())),
3842+
)
3843+
.await
3844+
.unwrap();
3845+
s.commit("conflicting", None).await.unwrap();
3846+
};
3847+
3848+
let res = session
3849+
.commit_rebasing(
3850+
&ConflictDetector,
3851+
3,
3852+
"updated non-conflict chunk",
3853+
None,
3854+
|_| async {},
3855+
conflicting,
3856+
)
3857+
.await;
3858+
3859+
// It has to give up eventually
3860+
assert!(matches!(
3861+
res,
3862+
Err(SessionError { kind: SessionErrorKind::Conflict { .. }, .. })
3863+
));
3864+
3865+
// It has to rebase 3 times
3866+
assert_eq!(attempts.into_inner(), 3);
3867+
3868+
let attempts = AtomicU16::new(0);
3869+
let attempts_ref = &attempts;
3870+
3871+
// now we'll create a new conflict twice, and finally do nothing so the commit can succeed
3872+
let conflicting_twice = |attempt| async move {
3873+
attempts_ref.fetch_add(1, Ordering::SeqCst); //*attempts_ref = *attempts_ref + 1;;
3874+
assert_eq!(attempt, attempts_ref.load(Ordering::SeqCst));
3875+
if attempt <= 2 {
3876+
let repo_c = Arc::clone(repo_ref);
3877+
3878+
let mut s = repo_c.writable_session("main").await.unwrap();
3879+
s.set_chunk_ref(
3880+
"/array".try_into().unwrap(),
3881+
ChunkIndices(vec![2]),
3882+
Some(ChunkPayload::Inline("repo 1".into())),
3883+
)
3884+
.await
3885+
.unwrap();
3886+
s.commit("conflicting", None).await.unwrap();
3887+
}
3888+
};
3889+
3890+
let res = session
3891+
.commit_rebasing(
3892+
&ConflictDetector,
3893+
42,
3894+
"updated non-conflict chunk",
3895+
None,
3896+
|_| async {},
3897+
conflicting_twice,
3898+
)
3899+
.await;
3900+
3901+
// The commit has to work after 3 rebase attempts
3902+
assert!(res.is_ok());
3903+
assert_eq!(attempts.into_inner(), 3);
3904+
Ok(())
3905+
}
3906+
37553907
#[cfg(test)]
37563908
mod state_machine_test {
37573909
use crate::format::Path;

0 commit comments

Comments
 (0)