diff --git a/icechunk-python/python/icechunk/credentials.py b/icechunk-python/python/icechunk/credentials.py index e3a603095..acb331597 100644 --- a/icechunk-python/python/icechunk/credentials.py +++ b/icechunk-python/python/icechunk/credentials.py @@ -1,7 +1,9 @@ +import asyncio +import inspect import pickle -from collections.abc import Callable, Mapping +from collections.abc import Awaitable, Callable, Coroutine, Mapping from datetime import datetime -from typing import cast +from typing import Any, TypeVar, cast from icechunk._icechunk_python import ( AzureCredentials, @@ -46,16 +48,50 @@ AnyCredential = Credentials.S3 | Credentials.Gcs | Credentials.Azure +S3CredentialProvider = Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] +GcsCredentialProvider = Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] +CredentialType = TypeVar("CredentialType") + + +def _resolve_callback_once( + get_credentials: Callable[[], CredentialType | Awaitable[CredentialType]], +) -> CredentialType: + value = get_credentials() + if inspect.isawaitable(value): + if inspect.iscoroutine(value): + coroutine = cast(Coroutine[Any, Any, CredentialType], value) + else: + awaitable = cast(Awaitable[CredentialType], value) + + async def _as_coroutine() -> CredentialType: + return await awaitable + + coroutine = _as_coroutine() + + try: + return asyncio.run(coroutine) + except RuntimeError as err: + coroutine.close() + if "asyncio.run() cannot be called from a running event loop" in str(err): + raise ValueError( + "scatter_initial_credentials=True cannot eagerly evaluate async " + "credential callbacks while an event loop is running. " + "Set scatter_initial_credentials=False in async contexts." 
+ ) from err + raise + + return value + def s3_refreshable_credentials( - get_credentials: Callable[[], S3StaticCredentials], + get_credentials: S3CredentialProvider, scatter_initial_credentials: bool = False, ) -> S3Credentials.Refreshable: """Create refreshable credentials for S3 and S3 compatible object stores. Parameters ---------- - get_credentials: Callable[[], S3StaticCredentials] + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] Use this function to get and refresh the credentials. The function must be pickable. scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -64,7 +100,9 @@ def s3_refreshable_credentials( set of credentials has expired, the cached value is no longer used. Notice that credentials obtained are stored, and they can be sent over the network if you pickle the session/repo. """ - current = get_credentials() if scatter_initial_credentials else None + current = ( + _resolve_callback_once(get_credentials) if scatter_initial_credentials else None + ) return S3Credentials.Refreshable(pickle.dumps(get_credentials), current) @@ -116,7 +154,7 @@ def s3_credentials( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: S3CredentialProvider | None = None, scatter_initial_credentials: bool = False, ) -> AnyS3Credential: """Create credentials for S3 and S3 compatible object stores. 
@@ -137,7 +175,7 @@ def s3_credentials( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -218,14 +256,14 @@ def gcs_static_credentials( def gcs_refreshable_credentials( - get_credentials: Callable[[], GcsBearerCredential], + get_credentials: GcsCredentialProvider, scatter_initial_credentials: bool = False, ) -> GcsCredentials.Refreshable: """Create refreshable credentials for Google Cloud Storage object store. Parameters ---------- - get_credentials: Callable[[], S3StaticCredentials] + get_credentials: Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] Use this function to get and refresh the credentials. The function must be pickable. scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -235,7 +273,9 @@ def gcs_refreshable_credentials( obtained are stored, and they can be sent over the network if you pickle the session/repo. 
""" - current = get_credentials() if scatter_initial_credentials else None + current = ( + _resolve_callback_once(get_credentials) if scatter_initial_credentials else None + ) return GcsCredentials.Refreshable(pickle.dumps(get_credentials), current) @@ -257,7 +297,7 @@ def gcs_credentials( bearer_token: str | None = None, from_env: bool | None = None, anonymous: bool | None = None, - get_credentials: Callable[[], GcsBearerCredential] | None = None, + get_credentials: GcsCredentialProvider | None = None, scatter_initial_credentials: bool = False, ) -> AnyGcsCredential: """Create credentials Google Cloud Storage object store. diff --git a/icechunk-python/python/icechunk/storage.py b/icechunk-python/python/icechunk/storage.py index dbe0007e1..d66bc1c8a 100644 --- a/icechunk-python/python/icechunk/storage.py +++ b/icechunk-python/python/icechunk/storage.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Awaitable, Callable from datetime import datetime from icechunk._icechunk_python import ( @@ -125,7 +125,9 @@ def s3_storage( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: ( + Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None + ) = None, scatter_initial_credentials: bool = False, force_path_style: bool = False, network_stream_timeout_seconds: int = 60, @@ -157,7 +159,7 @@ def s3_storage( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. 
This is useful if the @@ -254,7 +256,9 @@ def tigris_storage( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: ( + Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None + ) = None, scatter_initial_credentials: bool = False, network_stream_timeout_seconds: int = 60, ) -> Storage: @@ -288,7 +292,7 @@ def tigris_storage( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -340,7 +344,9 @@ def r2_storage( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: ( + Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None + ) = None, scatter_initial_credentials: bool = False, network_stream_timeout_seconds: int = 60, ) -> Storage: @@ -374,7 +380,7 @@ def r2_storage( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. 
This is useful if the @@ -436,7 +442,9 @@ def gcs_storage( anonymous: bool | None = None, from_env: bool | None = None, config: dict[str, str] | None = None, - get_credentials: Callable[[], GcsBearerCredential] | None = None, + get_credentials: ( + Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] | None + ) = None, scatter_initial_credentials: bool = False, ) -> Storage: """Create a Storage instance that saves data in Google Cloud Storage object store. @@ -461,7 +469,7 @@ def gcs_storage( Fetch credentials from the operative system environment config: dict[str, str] | None A dictionary of options for the Google Cloud Storage object store. See https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html#variants for a list of possible configuration keys. - get_credentials: Callable[[], GcsBearerCredential] | None + get_credentials: Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. 
This is useful if the diff --git a/icechunk-python/src/asyncio_bridge.rs b/icechunk-python/src/asyncio_bridge.rs new file mode 100644 index 000000000..99e596386 --- /dev/null +++ b/icechunk-python/src/asyncio_bridge.rs @@ -0,0 +1,64 @@ +use pyo3::{ + PyErr, PyResult, Python, + exceptions::PyRuntimeError, + types::{PyAnyMethods, PyModule}, +}; +use std::sync::OnceLock; + +pub(crate) fn current_task_locals() -> Option<pyo3_async_runtimes::TaskLocals> { + Python::attach(|py| pyo3_async_runtimes::tokio::get_current_locals(py).ok()) +} + +static FALLBACK_TASK_LOCALS: OnceLock<Result<pyo3_async_runtimes::TaskLocals, String>> = + OnceLock::new(); + +fn create_fallback_task_locals() -> Result<pyo3_async_runtimes::TaskLocals, String> { + let (tx, rx) = std::sync::mpsc::sync_channel::< + Result<pyo3_async_runtimes::TaskLocals, String>, + >(1); + + std::thread::Builder::new() + .name("icechunk-python-asyncio-fallback".to_string()) + .spawn(move || { + let mut tx = Some(tx); + let run_result = Python::attach(|py| -> PyResult<()> { + let asyncio = PyModule::import(py, "asyncio")?; + let event_loop = asyncio.getattr("new_event_loop")?.call0()?; + asyncio.getattr("set_event_loop")?.call1((event_loop.clone(),))?; + + let locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone()); + if let Some(sender) = tx.take() { + let _ = sender.send(Ok(locals)); + } else { + return Err(PyRuntimeError::new_err( + "fallback asyncio loop sender was unexpectedly missing", + )); + } + + event_loop.call_method0("run_forever")?; + Ok(()) + }); + + if let Err(err) = run_result + && let Some(sender) = tx + { + let _ = sender.send(Err(err.to_string())); + } + }) + .map_err(|err| format!("failed to start fallback asyncio loop thread: {err}"))?; + + match rx.recv() { + Ok(Ok(locals)) => Ok(locals), + Ok(Err(err)) => Err(format!("failed to initialize fallback asyncio loop: {err}")), + Err(err) => Err(format!( + "fallback asyncio loop thread terminated before initialization: {err}" + )), + } +} + +pub(crate) fn fallback_task_locals() -> Result<pyo3_async_runtimes::TaskLocals, PyErr> { + match FALLBACK_TASK_LOCALS.get_or_init(create_fallback_task_locals) { + Ok(locals) => Ok(locals.clone()), +
Err(err) => Err(PyRuntimeError::new_err(err.clone())), + } +} diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 68e80a0c1..929d4680e 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -29,10 +29,14 @@ use icechunk::{ }; use pyo3::{ Bound, FromPyObject, Py, PyErr, PyResult, Python, pyclass, pymethods, - types::{PyAnyMethods, PyModule, PyType}, + types::{PyAny, PyAnyMethods, PyModule, PyType}, }; -use crate::errors::PyIcechunkStoreError; +use crate::{ + asyncio_bridge::{current_task_locals, fallback_task_locals}, + errors::PyIcechunkStoreError, + sync::would_deadlock_current_loop, +}; #[pyclass(name = "S3StaticCredentials")] #[derive(Clone, Debug)] @@ -141,25 +145,48 @@ pub(crate) fn datetime_repr(d: &DateTime) -> String { struct PythonCredentialsFetcher { pub pickled_function: Vec, pub initial: Option, + // Intentionally skipped during serialization — on deserialization, + // `await_python_callback` calls `current_task_locals()` to get + // fresh locals from whatever event loop is active at that point. 
+ #[serde(skip, default)] + pub task_locals: Option, } impl PythonCredentialsFetcher { fn new(pickled_function: Vec) -> Self { - PythonCredentialsFetcher { pickled_function, initial: None } + PythonCredentialsFetcher { + pickled_function, + initial: None, + task_locals: current_task_locals(), + } } fn new_with_initial(pickled_function: Vec, current: C) -> Self where C: Into, { - PythonCredentialsFetcher { pickled_function, initial: Some(current.into()) } + PythonCredentialsFetcher { + pickled_function, + initial: Some(current.into()), + task_locals: current_task_locals(), + } } } +enum PythonCallbackResult { + Value(PyCred), + Awaitable(Py), +} + +fn is_awaitable(py: Python<'_>, value: &Bound<'_, PyAny>) -> Result { + let inspect = PyModule::import(py, "inspect")?; + inspect.getattr("isawaitable")?.call1((value,))?.extract() +} + fn call_pickled( py: Python<'_>, - pickled_function: Vec, -) -> Result + pickled_function: &[u8], +) -> Result, PyErr> where PyCred: for<'a, 'py> FromPyObject<'a, 'py>, for<'a, 'py> >::Error: Into, @@ -167,8 +194,48 @@ where let pickle_module = PyModule::import(py, "pickle")?; let loads_function = pickle_module.getattr("loads")?; let fetcher = loads_function.call1((pickled_function,))?; - let creds: PyCred = fetcher.call0()?.extract().map_err(Into::into)?; - Ok(creds) + let value = fetcher.call0()?; + let value_for_extract = value.clone(); + let extract_err = match value_for_extract.extract::() { + Ok(creds) => return Ok(PythonCallbackResult::Value(creds)), + Err(extract_err) => extract_err, + }; + + if is_awaitable(py, &value)? { + Ok(PythonCallbackResult::Awaitable(value.unbind())) + } else { + Err(extract_err.into()) + } +} + +async fn await_python_callback( + awaitable: Py, + task_locals: Option, +) -> Result +where + PyCred: for<'a, 'py> FromPyObject<'a, 'py>, + for<'a, 'py> >::Error: Into, +{ + let task_locals = match current_task_locals().or(task_locals) { + Some(task_locals) => { + if would_deadlock_current_loop(&task_locals)? 
{ + return Err(PyValueError::new_err( + "deadlock: async credential callback targets the currently blocked event loop thread", + )); + } + task_locals + } + None => fallback_task_locals()?, + }; + + let fut = Python::attach(|py| { + pyo3_async_runtimes::into_future_with_locals( + &task_locals, + awaitable.bind(py).clone(), + ) + })?; + let value = fut.await?; + Python::attach(|py| value.bind(py).extract().map_err(Into::into)) } #[async_trait] @@ -190,11 +257,23 @@ impl S3CredentialsFetcher for PythonCredentialsFetcher { _ => {} } } - Python::attach(|py| { - call_pickled::(py, self.pickled_function.clone()) - .map(|c| c.into()) - }) - .map_err(|e: PyErr| e.to_string()) + let callback_result = Python::attach(|py| { + call_pickled::(py, &self.pickled_function) + }); + + let creds = match callback_result { + Ok(PythonCallbackResult::Value(creds)) => Ok(creds), + Ok(PythonCallbackResult::Awaitable(awaitable)) => { + await_python_callback::( + awaitable, + self.task_locals.clone(), + ) + .await + } + Err(err) => Err(err), + }; + + creds.map(Into::into).map_err(|e: PyErr| e.to_string()) } } @@ -217,11 +296,23 @@ impl GcsCredentialsFetcher for PythonCredentialsFetcher { _ => {} } } - Python::attach(|py| { - call_pickled::(py, self.pickled_function.clone()) - .map(|c| c.into()) - }) - .map_err(|e: PyErr| e.to_string()) + let callback_result = Python::attach(|py| { + call_pickled::(py, &self.pickled_function) + }); + + let creds = match callback_result { + Ok(PythonCallbackResult::Value(creds)) => Ok(creds), + Ok(PythonCallbackResult::Awaitable(awaitable)) => { + await_python_callback::( + awaitable, + self.task_locals.clone(), + ) + .await + } + Err(err) => Err(err), + }; + + creds.map(Into::into).map_err(|e: PyErr| e.to_string()) } } diff --git a/icechunk-python/src/lib.rs b/icechunk-python/src/lib.rs index 4426fcc26..e14aa3f0d 100644 --- a/icechunk-python/src/lib.rs +++ b/icechunk-python/src/lib.rs @@ -1,3 +1,4 @@ +mod asyncio_bridge; mod config; mod conflicts; mod 
errors; @@ -7,6 +8,7 @@ mod session; mod stats; mod store; mod streams; +mod sync; use std::env; diff --git a/icechunk-python/src/repository.rs b/icechunk-python/src/repository.rs index 8b9fb20d6..8d3a93c46 100644 --- a/icechunk-python/src/repository.rs +++ b/icechunk-python/src/repository.rs @@ -47,6 +47,7 @@ use crate::{ session::PySession, stats::PyChunkStorageStats, streams::PyAsyncGenerator, + sync::ensure_not_running_event_loop, }; /// Wrapper needed to implement pyo3 conversion classes @@ -815,6 +816,7 @@ impl PyRepository { dry_run: bool, delete_unused_v1_files: bool, ) -> PyResult<()> { + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let mut repo = self.0.write().await; @@ -842,6 +844,7 @@ impl PyRepository { spec_version: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let repository = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -912,6 +915,7 @@ impl PyRepository { authorize_virtual_chunk_access: Option>>, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let repository = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -964,6 +968,7 @@ impl PyRepository { create_version: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let repository = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -1028,6 +1033,7 @@ impl PyRepository { #[staticmethod] fn exists(py: Python<'_>, storage: PyStorage) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { 
pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let exists = Repository::exists(storage.0) @@ -1051,6 +1057,7 @@ impl PyRepository { #[staticmethod] fn fetch_spec_version(py: Python<'_>, storage: PyStorage) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let spec_version = Repository::fetch_spec_version(storage.0) @@ -1088,6 +1095,7 @@ impl PyRepository { Option>>, >, ) -> PyResult { + ensure_not_running_event_loop(py)?; py.detach(move || { let config = config .map(|c| c.try_into().map_err(PyValueError::new_err)) @@ -1165,6 +1173,7 @@ impl PyRepository { storage: PyStorage, ) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let res = Repository::fetch_config(storage.0) @@ -1193,6 +1202,7 @@ impl PyRepository { fn save_config(&self, py: Python<'_>) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let _etag = self @@ -1259,6 +1269,7 @@ impl PyRepository { } pub(crate) fn get_metadata(&self, py: Python<'_>) -> PyResult { + ensure_not_running_event_loop(py)?; py.detach(move || { let metadata = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -1294,6 +1305,7 @@ impl PyRepository { py: Python<'_>, metadata: PySnapshotProperties, ) -> PyResult<()> { + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1328,6 +1340,7 @@ impl PyRepository { py: Python<'_>, metadata: PySnapshotProperties, ) -> PyResult { + 
ensure_not_running_event_loop(py)?; py.detach(move || { let res = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1368,6 +1381,7 @@ impl PyRepository { snapshot_id: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let version = args_to_version_info(branch, tag, snapshot_id, None)?; let ancestry = pyo3_async_runtimes::tokio::get_runtime() @@ -1392,6 +1406,7 @@ impl PyRepository { pub(crate) fn async_ops_log(&self, py: Python<'_>) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let ops = pyo3_async_runtimes::tokio::get_runtime() .block_on(async move { @@ -1417,6 +1432,7 @@ impl PyRepository { snapshot_id: &str, ) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1462,6 +1478,7 @@ impl PyRepository { pub(crate) fn list_branches(&self, py: Python<'_>) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let branches = self @@ -1500,6 +1517,7 @@ impl PyRepository { branch_name: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let tip = self @@ -1537,6 +1555,7 @@ impl PyRepository { snapshot_id: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + 
ensure_not_running_event_loop(py)?; py.detach(move || { let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1579,8 +1598,10 @@ impl PyRepository { pub(crate) fn list_manifest_files( &self, + py: Python<'_>, snapshot_id: &str, ) -> PyResult> { + ensure_not_running_event_loop(py)?; let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( RepositoryErrorKind::InvalidSnapshotId(snapshot_id.to_owned()).into(), @@ -1634,6 +1655,7 @@ impl PyRepository { from_snapshot_id: Option<&str>, ) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let to_snapshot_id = SnapshotId::try_from(to_snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1701,6 +1723,7 @@ impl PyRepository { pub(crate) fn delete_branch(&self, py: Python<'_>, branch: &str) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1733,6 +1756,7 @@ impl PyRepository { pub(crate) fn delete_tag(&self, py: Python<'_>, tag: &str) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1770,6 +1794,7 @@ impl PyRepository { snapshot_id: &str, ) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1815,6 +1840,7 @@ impl PyRepository { pub(crate) fn list_tags(&self, py: Python<'_>) -> PyResult> { // This 
function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let tags = self @@ -1846,6 +1872,7 @@ impl PyRepository { pub(crate) fn lookup_tag(&self, py: Python<'_>, tag: &str) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let tag = self @@ -1893,6 +1920,7 @@ impl PyRepository { let to = args_to_version_info(to_branch, to_tag, to_snapshot_id, None)?; // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let diff = self @@ -1943,6 +1971,7 @@ impl PyRepository { as_of: Option>, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let version = args_to_version_info(branch, tag, snapshot_id, as_of)?; let session = @@ -1987,6 +2016,7 @@ impl PyRepository { branch: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let session = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2025,6 +2055,7 @@ impl PyRepository { branch: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let session = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2066,6 +2097,7 @@ impl PyRepository { metadata: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + 
ensure_not_running_event_loop(py)?; py.detach(move || { let metadata = metadata.map(|m| m.into()); let result = @@ -2128,6 +2160,7 @@ impl PyRepository { delete_expired_tags: bool, ) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let result = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2210,6 +2243,7 @@ impl PyRepository { max_concurrent_manifest_fetches: NonZeroU16, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let result = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2275,6 +2309,7 @@ impl PyRepository { max_concurrent_manifest_fetches: NonZeroU16, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let stats = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2326,7 +2361,13 @@ impl PyRepository { } #[pyo3(signature = (snapshot_id, *, pretty = true))] - fn inspect_snapshot(&self, snapshot_id: String, pretty: bool) -> PyResult { + fn inspect_snapshot( + &self, + py: Python<'_>, + snapshot_id: String, + pretty: bool, + ) -> PyResult { + ensure_not_running_event_loop(py)?; let result = pyo3_async_runtimes::tokio::get_runtime() .block_on(async move { let lock = self.0.read().await; @@ -2362,11 +2403,14 @@ impl PyRepository { } #[getter] - fn spec_version(&self) -> u8 { - pyo3_async_runtimes::tokio::get_runtime().block_on(async move { - let repo = self.0.read().await; - repo.spec_version() - }) as u8 + fn spec_version(&self, py: Python<'_>) -> PyResult { + ensure_not_running_event_loop(py)?; + let spec_version = + pyo3_async_runtimes::tokio::get_runtime().block_on(async move { + let repo = self.0.read().await; + repo.spec_version() + }) as u8; + Ok(spec_version) } 
} diff --git a/icechunk-python/src/session.rs b/icechunk-python/src/session.rs index b5e8f0812..6bce22726 100644 --- a/icechunk-python/src/session.rs +++ b/icechunk-python/src/session.rs @@ -21,6 +21,7 @@ use crate::{ repository::{PyDiff, PySnapshotProperties}, store::PyStore, streams::PyAsyncGenerator, + sync::ensure_not_running_event_loop, }; #[pyclass] @@ -118,6 +119,7 @@ impl PySession { pub fn status(&self, py: Python<'_>) -> PyResult { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); @@ -152,6 +154,7 @@ impl PySession { let to = Path::new(to_path.as_str()) .map_err(|e| StoreError::from(StoreErrorKind::PathError(e))) .map_err(PyIcechunkStoreError::StoreError)?; + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let mut session = self.0.write().await; @@ -193,6 +196,7 @@ impl PySession { array_path: String, shift_chunk: Bound<'py, PyFunction>, ) -> PyResult<()> { + ensure_not_running_event_loop(py)?; let array_path = Path::new(array_path.as_str()) .map_err(|e| StoreError::from(StoreErrorKind::PathError(e))) .map_err(PyIcechunkStoreError::StoreError)?; @@ -226,7 +230,13 @@ impl PySession { }) } - pub fn shift_array(&mut self, array_path: String, offset: Vec) -> PyResult<()> { + pub fn shift_array( + &mut self, + py: Python<'_>, + array_path: String, + offset: Vec, + ) -> PyResult<()> { + ensure_not_running_event_loop(py)?; let array_path = Path::new(array_path.as_str()) .map_err(|e| StoreError::from(StoreErrorKind::PathError(e))) .map_err(PyIcechunkStoreError::StoreError)?; @@ -267,6 +277,7 @@ impl PySession { pub fn all_virtual_chunk_locations(&self, py: Python<'_>) -> PyResult> { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); @@ -354,9 +365,11 @@ impl PySession { pub fn chunk_type( 
&self, + py: Python<'_>, array_path: String, coords: Vec, ) -> PyResult { + ensure_not_running_event_loop(py)?; let session = self.0.clone(); pyo3_async_runtimes::tokio::get_runtime() .block_on(Self::chunk_type_inner(session, array_path, coords)) @@ -377,6 +390,7 @@ impl PySession { pub fn merge(&self, other: &PySession, py: Python<'_>) -> PyResult<()> { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { // TODO: bad clone let other = other.0.blocking_read().deref().clone(); @@ -421,6 +435,7 @@ impl PySession { ) -> PyResult { let metadata = metadata.map(|m| m.into()); // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async { let mut session = self.0.write().await; @@ -489,6 +504,7 @@ impl PySession { ) -> PyResult { let metadata = metadata.map(|m| m.into()); // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async { let mut session = self.0.write().await; @@ -532,6 +548,7 @@ impl PySession { ) -> PyResult { let metadata = metadata.map(|m| m.into()); // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async { let mut session = self.0.write().await; @@ -566,6 +583,7 @@ impl PySession { pub fn rebase(&self, solver: PyConflictSolver, py: Python<'_>) -> PyResult<()> { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let solver = solver.as_ref(); pyo3_async_runtimes::tokio::get_runtime().block_on(async { diff --git a/icechunk-python/src/sync.rs b/icechunk-python/src/sync.rs new file mode 100644 index 000000000..0c79506bf --- /dev/null +++ b/icechunk-python/src/sync.rs @@ -0,0 +1,35 @@ +use pyo3::{ + 
Bound, PyErr, PyResult, Python, + exceptions::{PyRuntimeError, PyValueError}, + types::{PyAny, PyAnyMethods, PyModule}, +}; + +fn object_id(py: Python<'_>, obj: &Bound<'_, PyAny>) -> Result { + let builtins = PyModule::import(py, "builtins")?; + builtins.getattr("id")?.call1((obj,))?.extract() +} + +pub(crate) fn would_deadlock_current_loop( + task_locals: &pyo3_async_runtimes::TaskLocals, +) -> Result { + Python::attach(|py| { + let running_loop = match pyo3_async_runtimes::get_running_loop(py) { + Ok(loop_ref) => loop_ref, + Err(err) if err.is_instance_of::(py) => return Ok(false), + Err(err) => return Err(err), + }; + + let target_loop = task_locals.event_loop(py); + Ok(object_id(py, &running_loop)? == object_id(py, &target_loop)?) + }) +} + +pub(crate) fn ensure_not_running_event_loop(py: Python<'_>) -> PyResult<()> { + match pyo3_async_runtimes::get_running_loop(py) { + Ok(_) => Err(PyValueError::new_err( + "synchronous API called from a running event loop thread (deadlock risk); this usage is disallowed. 
Use the async API or run the synchronous call in a worker thread", + )), + Err(err) if err.is_instance_of::(py) => Ok(()), + Err(err) => Err(err), + } +} diff --git a/icechunk-python/tests/conftest.py b/icechunk-python/tests/conftest.py index 87dbbf4bc..cd83c6712 100644 --- a/icechunk-python/tests/conftest.py +++ b/icechunk-python/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio from typing import Literal, cast import boto3 @@ -22,6 +23,12 @@ def parse_repo( ) +async def parse_repo_async( + store: Literal["local", "memory"], path: str, spec_version: int | None +) -> Repository: + return await asyncio.to_thread(parse_repo, store, path, spec_version) + + @pytest.fixture(scope="function") def repo( request: pytest.FixtureRequest, tmpdir: str, any_spec_version: int | None diff --git a/icechunk-python/tests/test_can_read_old.py b/icechunk-python/tests/test_can_read_old.py index f0f797889..4b36edba3 100644 --- a/icechunk-python/tests/test_can_read_old.py +++ b/icechunk-python/tests/test_can_read_old.py @@ -11,6 +11,7 @@ file as a python script: `python ./tests/test_can_read_old.py`. 
""" +import asyncio import shutil from datetime import UTC, datetime from typing import Any, cast @@ -86,8 +87,10 @@ async def write_a_split_repo(path: str) -> None: ) print(f"Writing repository to {store_path}") - repo = mk_repo(create=True, store_path=store_path, config=config) - session = repo.writable_session("main") + repo = await asyncio.to_thread( + mk_repo, create=True, store_path=store_path, config=config + ) + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) @@ -115,47 +118,49 @@ async def write_a_split_repo(path: str) -> None: fill_value=8, attributes={"this": "is a nice array", "icechunk": 1, "size": 42.0}, ) - session.commit("empty structure") + await session.commit_async("empty structure") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") big_chunks = zarr.open_array(session.store, path="/group1/split", mode="a") small_chunks = zarr.open_array(session.store, path="/group1/small_chunks", mode="a") big_chunks[:] = 120 small_chunks[:] = 0 - session.commit("write data") + await session.commit_async("write data") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") big_chunks = zarr.open_array(session.store, path="group1/split", mode="a") small_chunks = zarr.open_array(session.store, path="group1/small_chunks", mode="a") big_chunks[:] = 12 small_chunks[:] = 1 - session.commit("write data again") + await session.commit_async("write data again") ### new config config = ic.RepositoryConfig.default() config.inline_chunk_threshold_bytes = 12 config.manifest = ic.ManifestConfig(splitting=UPDATED_SPLITTING_CONFIG) - repo = mk_repo(create=False, store_path=store_path, config=config) - repo.save_config() - session = repo.writable_session("main") + repo = await asyncio.to_thread( + mk_repo, create=False, store_path=store_path, config=config + ) + await repo.save_config_async() + session = await 
repo.writable_session_async("main") big_chunks = zarr.open_array(session.store, path="group1/split", mode="a") small_chunks = zarr.open_array(session.store, path="group1/small_chunks", mode="a") big_chunks[:] = 14 small_chunks[:] = 3 - session.commit("write data again with more splits") + await session.commit_async("write data again with more splits") async def do_icechunk_can_read_old_repo_with_manifest_splitting(path: str) -> None: - repo = mk_repo(create=False, store_path=path) - ancestry = list(repo.ancestry(branch="main"))[::-1] + repo = await asyncio.to_thread(mk_repo, create=False, store_path=path) + ancestry = (await asyncio.to_thread(lambda: list(repo.ancestry(branch="main"))))[::-1] init_snapshot = ancestry[1] assert init_snapshot.message == "empty structure" - assert len(repo.list_manifest_files(init_snapshot.id)) == 0 + assert len(await repo.list_manifest_files_async(init_snapshot.id)) == 0 snapshot = ancestry[2] assert snapshot.message == "write data" - assert len(repo.list_manifest_files(snapshot.id)) == 9 + assert len(await repo.list_manifest_files_async(snapshot.id)) == 9 snapshot = ancestry[3] assert snapshot.message == "write data again" @@ -180,8 +185,8 @@ async def write_a_test_repo(path: str) -> None: """ print(f"Writing repository to {path}") - repo = mk_repo(create=True, store_path=path) - session = repo.writable_session("main") + repo = await asyncio.to_thread(mk_repo, create=True, store_path=path) + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) @@ -208,9 +213,9 @@ async def write_a_test_repo(path: str) -> None: fill_value=8, attributes={"this": "is a nice array", "icechunk": 1, "size": 42.0}, ) - session.commit("empty structure") + await session.commit_async("empty structure") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) big_chunks = cast("zarr.Array[Any]", 
root["group1/big_chunks"]) @@ -218,8 +223,8 @@ async def write_a_test_repo(path: str) -> None: big_chunks[:] = 42.0 small_chunks[:] = 84 - snapshot = session.commit("fill data") - session = repo.writable_session("main") + snapshot = await session.commit_async("fill data") + session = await repo.writable_session_async("main") store = session.store # We are going to write this chunk to storage as a virtual chunk @@ -236,22 +241,22 @@ async def write_a_test_repo(path: str) -> None: length=virtual_chunk_data_size, checksum=datetime(9999, 12, 31, tzinfo=UTC), ) - snapshot = session.commit("set virtual chunk") - session = repo.writable_session("main") + snapshot = await session.commit_async("set virtual chunk") + session = await repo.writable_session_async("main") store = session.store - repo.create_branch("my-branch", snapshot_id=snapshot) - session = repo.writable_session("my-branch") + await repo.create_branch_async("my-branch", snapshot_id=snapshot) + session = await repo.writable_session_async("my-branch") store = session.store await store.delete("group1/small_chunks/c/4") - snap4 = session.commit("delete a chunk") + snap4 = await session.commit_async("delete a chunk") - repo.create_tag("it works!", snapshot_id=snap4) - repo.create_tag("deleted", snapshot_id=snap4) - repo.delete_tag("deleted") + await repo.create_tag_async("it works!", snapshot_id=snap4) + await repo.create_tag_async("deleted", snapshot_id=snap4) + await repo.delete_tag_async("deleted") - session = repo.writable_session("my-branch") + session = await repo.writable_session_async("my-branch") store = session.store root = zarr.open_group(store=store) @@ -275,9 +280,9 @@ async def write_a_test_repo(path: str) -> None: fill_value=float("nan"), attributes={"this": "is a nice array", "icechunk": 1, "size": 42.0}, ) - snap5 = session.commit("some more structure") + snap5 = await session.commit_async("some more structure") - repo.create_tag("it also works!", snapshot_id=snap5) + await 
repo.create_tag_async("it also works!", snapshot_id=snap5) store.close() @@ -286,8 +291,8 @@ async def do_icechunk_can_read_old_repo(path: str) -> None: # we import here so it works when the script is ran by pytest from tests.conftest import write_chunks_to_minio - repo = mk_repo(create=False, store_path=path) - main_snapshot = repo.lookup_branch("main") + repo = await asyncio.to_thread(mk_repo, create=False, store_path=path) + main_snapshot = await repo.lookup_branch_async("main") expected_main_history = [ "set virtual chunk", @@ -296,7 +301,10 @@ async def do_icechunk_can_read_old_repo(path: str) -> None: "Repository initialized", ] assert [ - p.message for p in repo.ancestry(snapshot_id=main_snapshot) + p.message + for p in await asyncio.to_thread( + lambda: list(repo.ancestry(snapshot_id=main_snapshot)) + ) ] == expected_main_history expected_branch_history = [ @@ -305,21 +313,26 @@ async def do_icechunk_can_read_old_repo(path: str) -> None: ] + expected_main_history assert [ - p.message for p in repo.ancestry(branch="my-branch") + p.message + for p in await asyncio.to_thread(lambda: list(repo.ancestry(branch="my-branch"))) ] == expected_branch_history assert [ - p.message for p in repo.ancestry(tag="it also works!") + p.message + for p in await asyncio.to_thread( + lambda: list(repo.ancestry(tag="it also works!")) + ) ] == expected_branch_history - assert [p.message for p in repo.ancestry(tag="it works!")] == expected_branch_history[ - 1: - ] + assert [ + p.message + for p in await asyncio.to_thread(lambda: list(repo.ancestry(tag="it works!"))) + ] == expected_branch_history[1:] with pytest.raises(ic.IcechunkError, match="ref not found"): - repo.readonly_session(tag="deleted") + await repo.readonly_session_async(tag="deleted") - session = repo.writable_session("my-branch") + session = await repo.writable_session_async("my-branch") store = session.store assert sorted([p async for p in store.list_dir("")]) == [ "group1", @@ -374,8 +387,8 @@ async def 
do_icechunk_can_read_old_repo(path: str) -> None: big_chunks = cast("zarr.Array[Any]", root["group1/big_chunks"]) assert_array_equal(big_chunks[:], 42.0) - parents = list(repo.ancestry(branch="main")) - diff = repo.diff(to_branch="main", from_snapshot_id=parents[-2].id) + parents = await asyncio.to_thread(lambda: list(repo.ancestry(branch="main"))) + diff = await repo.diff_async(to_branch="main", from_snapshot_id=parents[-2].id) assert diff.new_groups == set() assert diff.new_arrays == set() assert set(diff.updated_chunks.keys()) == { diff --git a/icechunk-python/tests/test_concurrency.py b/icechunk-python/tests/test_concurrency.py index ca0261f6c..d2c4f1de4 100644 --- a/icechunk-python/tests/test_concurrency.py +++ b/icechunk-python/tests/test_concurrency.py @@ -52,12 +52,12 @@ async def list_store(store: icechunk.IcechunkStore, barrier: asyncio.Barrier) -> async def test_concurrency(any_spec_version: int | None) -> None: - repo = icechunk.Repository.open_or_create( + repo = await icechunk.Repository.open_or_create_async( storage=icechunk.in_memory_storage(), create_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) @@ -85,7 +85,7 @@ async def test_concurrency(any_spec_version: int | None) -> None: all_coords = {coords async for coords in session.chunk_coordinates("/array")} assert all_coords == {(x, y) for x in range(N) for y in range(N - 1)} - _res = session.commit("commit") + _res = await session.commit_async("commit") assert isinstance(group["array"], zarr.Array) array = group["array"] @@ -150,14 +150,14 @@ async def test_thread_concurrency(any_spec_version: int | None) -> None: ) # Open the store - repo = icechunk.Repository.create( + repo = await icechunk.Repository.create_async( storage=storage, config=config, authorize_virtual_chunk_access=credentials, spec_version=any_spec_version, ) - session = 
repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index 98e8460a9..d5776bb96 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -1,5 +1,9 @@ +import asyncio +import contextvars import pickle +import threading import time +from collections.abc import Iterator from datetime import UTC, datetime from pathlib import Path @@ -64,6 +68,31 @@ def returns_something_else() -> int: return 42 +ASYNC_CREDENTIALS_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "ASYNC_CREDENTIALS_CONTEXT", default=None +) + + +class AsyncGoodCredentials: + def __init__(self, path: Path, expected_context: str | None = None) -> None: + self.path = path + self.expected_context = expected_context + + async def __call__(self) -> S3StaticCredentials: + try: + calls = self.path.read_text() + except Exception: + calls = "" + self.path.write_text(calls + ".") + + if self.expected_context is not None: + if ASYNC_CREDENTIALS_CONTEXT.get() != self.expected_context: + raise IcechunkError("missing callback context") + + await asyncio.sleep(0) + return S3StaticCredentials(access_key_id="minio123", secret_access_key="minio123") + + @pytest.mark.parametrize( "scatter_initial_credentials", [False, True], @@ -222,3 +251,355 @@ def test_s3_refreshable_credentials_pickle_with_optimization( called_only_once = path.read_text() == "." 
assert called_only_once == scatter_initial_credentials + + +def test_async_refreshable_credentials_with_sync_repository_api( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_sync_api-" + str(int(time.time() * 1000)) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + + calls_path = tmp_path / "async_sync_calls.txt" + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials(calls_path), + ) + + repo = Repository.open(callback_storage) + assert "main" in repo.list_branches() + assert "." in calls_path.read_text() + + +def test_async_refreshable_credentials_constructed_sync_used_async( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_sync_construct_async_use-" + str( + int(time.time() * 1000) + ) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + + calls_path = tmp_path / "async_sync_construct_calls.txt" + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials(calls_path), + ) + + async def use_async_repository_api() -> None: + repo = await Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + + 
asyncio.run(use_async_repository_api()) + assert "." in calls_path.read_text() + + +def test_async_refreshable_credentials_repo_reused_across_event_loops( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_new_loop-" + str(int(time.time() * 1000)) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + + calls_path = tmp_path / "async_new_loop_calls.txt" + + async def create_and_use_repo_once() -> Repository: + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials(calls_path), + ) + repo = await Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + return repo + + repo = asyncio.run(create_and_use_repo_once()) + + async def use_repo_on_different_loop() -> None: + assert "main" in await repo.list_branches_async() + + asyncio.run(use_repo_on_different_loop()) + assert "." 
in calls_path.read_text() + + +@pytest.mark.asyncio +async def test_async_refreshable_credentials_with_async_repository_api( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_async_api-" + str(int(time.time() * 1000)) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + await Repository.create_async(storage=create_storage, spec_version=any_spec_version) + + calls_path = tmp_path / "async_async_calls.txt" + context_value = "expected-refresh-context" + reset_token = ASYNC_CREDENTIALS_CONTEXT.set(context_value) + try: + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials( + calls_path, expected_context=context_value + ), + ) + + repo = await Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + finally: + ASYNC_CREDENTIALS_CONTEXT.reset(reset_token) + + assert "." 
in calls_path.read_text() + + +class GoodException(Exception): + pass + + +class BadException(Exception): + pass + + +# survives pickle because pickle imports the module, doesn't recreate globals +_cred_refresh_loop_ids: list[int] = [] + + +@pytest.fixture(autouse=False) +def reset_cred_tracker() -> Iterator[list[int]]: + _cred_refresh_loop_ids.clear() + yield _cred_refresh_loop_ids + _cred_refresh_loop_ids.clear() + + +@pytest.fixture() +def minio_repo_prefix(any_spec_version: int | None) -> str: + prefix = "test_async_creds-" + str(int(time.time() * 1000)) + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + return prefix + + +class AsyncCredentialTracker: + def __init__( + self, + expected_loop_id: int | None = None, + return_creds: bool = False, + ): + self.expected_loop_id = expected_loop_id + self.return_creds = return_creds + + async def __call__(self) -> S3StaticCredentials: + await asyncio.sleep(0) + loop_id = id(asyncio.get_running_loop()) + _cred_refresh_loop_ids.append(loop_id) + if self.expected_loop_id is not None and loop_id != self.expected_loop_id: + raise BadException("Wrong event loop") + if self.return_creds: + return S3StaticCredentials( + access_key_id="minio123", + secret_access_key="minio123", + expires_after=datetime.now(UTC), + ) + raise GoodException("YOLO") + + +def test_async_cred_refresh_uses_same_loop( + minio_repo_prefix: str, reset_cred_tracker: list[int] +) -> None: + prefix = minio_repo_prefix + + async def run() -> None: + loop = asyncio.get_running_loop() + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + 
get_credentials=AsyncCredentialTracker( + expected_loop_id=id(loop), return_creds=True + ), + ) + repo = await Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + # FIXME: this deadlocks + # _ = [snap async for snap in repo.async_ancestry(branch="main")] + snap = await repo.lookup_branch_async("main") + await repo.create_branch_async("test-write", snap) + # all that should have triggered at least one refresh + assert len(_cred_refresh_loop_ids) > 1 + assert len(set(_cred_refresh_loop_ids)) == 1 + + asyncio.run(run()) + + +def test_async_cred_refresh_uses_originator_loop_from_thread( + minio_repo_prefix: str, reset_cred_tracker: list[int] +) -> None: + prefix = minio_repo_prefix + + async def run() -> None: + loop = asyncio.get_running_loop() + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncCredentialTracker( + expected_loop_id=id(loop), return_creds=True + ), + ) + + def sync_in_thread() -> None: + repo = Repository.open(callback_storage) + calls_after_open = len(_cred_refresh_loop_ids) + assert calls_after_open >= 1 + + repo.list_branches() + calls_after_list_branches = len(_cred_refresh_loop_ids) + assert calls_after_list_branches > calls_after_open + + list(repo.ancestry(branch="main")) + calls_after_ancestry = len(_cred_refresh_loop_ids) + assert calls_after_ancestry > calls_after_list_branches + + await loop.run_in_executor(None, sync_in_thread) + assert len(set(_cred_refresh_loop_ids)) == 1 + + asyncio.run(run()) + + +def test_async_cred_refresh_graceful_deadlock() -> None: + # Deadlock scenario: sync list_branches called directly on the event loop + # thread where the credential callback would also need to run. Use a thread + # with a timeout so the test doesn't hang forever if it actually deadlocks. 
+ caught_error: BaseException | None = None + + async def sync_call_on_same_loop_as_callback() -> None: + loop = asyncio.get_running_loop() + current_loop_id = id(loop) + storage = s3_storage( + region="us-east-1", + bucket="testbucket", + prefix="placeholder", + get_credentials=AsyncCredentialTracker(expected_loop_id=current_loop_id), + ) + # A user can very easily make this mistake! + # we should be gracefully falling over not deadlocking + # with pytest.raises(IcechunkError, match="deadlock"): + Repository.open(storage=storage) + + def run_deadlock_check() -> None: + nonlocal caught_error + try: + asyncio.run(sync_call_on_same_loop_as_callback()) + except BaseException as e: + caught_error = e + + t = threading.Thread(target=run_deadlock_check, daemon=True) + t.start() + t.join(timeout=1) + assert not t.is_alive(), "Deadlocked: sync call blocked the event loop thread" + assert isinstance(caught_error, ValueError) + assert "deadlock" in str(caught_error) + + +def test_async_callback_no_loop_has_consistent_loop( + minio_repo_prefix: str, reset_cred_tracker: list[int] +) -> None: + prefix = minio_repo_prefix + + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncCredentialTracker(return_creds=True), + ) + + repo = Repository.open(callback_storage) + repo.list_branches() + list(repo.ancestry(branch="main")) + assert len(_cred_refresh_loop_ids) > 3 + + def on_another_thread(r: Repository) -> None: + r.list_branches() + + t = threading.Thread(target=on_another_thread, args=(repo,)) + t.start() + t.join() + assert len(_cred_refresh_loop_ids) > 4 + assert len(set(_cred_refresh_loop_ids)) == 1 + assert ( + len(set(_cred_refresh_loop_ids)) == 1 + ), "All credential refresh calls should be on the same loop even if there isn't an event loop at the start" diff --git a/icechunk-python/tests/test_distributed_writers.py 
b/icechunk-python/tests/test_distributed_writers.py index a856e1258..3f8191e64 100644 --- a/icechunk-python/tests/test_distributed_writers.py +++ b/icechunk-python/tests/test_distributed_writers.py @@ -1,3 +1,4 @@ +import asyncio import time import warnings from typing import Any, cast @@ -70,8 +71,8 @@ async def test_distributed_writers( does a distributed commit. When done, we open the store again and verify we can write everything we have written. """ - repo = mk_repo(any_spec_version, use_object_store) - session = repo.writable_session(branch="main") + repo = await asyncio.to_thread(mk_repo, any_spec_version, use_object_store) + session = await repo.writable_session_async(branch="main") store = session.store shape = (CHUNKS_PER_DIM * CHUNK_DIM_SIZE,) * 2 @@ -86,22 +87,22 @@ async def test_distributed_writers( dtype="f8", fill_value=float("nan"), ) - first_snap = session.commit("array created") + first_snap = await session.commit_async("array created") - def do_writes(branch_name: str) -> None: - repo.create_branch(branch_name, first_snap) - session = repo.writable_session(branch=branch_name) + async def do_writes(branch_name: str) -> None: + await repo.create_branch_async(branch_name, first_snap) + session = await repo.writable_session_async(branch=branch_name) fork = session.fork() group = zarr.open_group(store=fork.store) zarray = cast("zarr.Array[Any]", group["array"]) merged_session = store_dask(sources=[dask_array], targets=[zarray]) - session.merge(merged_session) - commit_res = session.commit("distributed commit") + await session.merge_async(merged_session) + commit_res = await session.commit_async("distributed commit") assert commit_res async def verify(branch_name: str) -> None: # Lets open a new store to verify the results - readonly_session = repo.readonly_session(branch=branch_name) + readonly_session = await repo.readonly_session_async(branch=branch_name) store = readonly_session.store all_keys = [key async for key in store.list_prefix("/")] assert 
( @@ -116,9 +117,9 @@ async def verify(branch_name: str) -> None: assert_eq(roundtripped, dask_array) # type: ignore [no-untyped-call] with Client(dashboard_address=":0"): # type: ignore[no-untyped-call] - do_writes("with-processes") + await do_writes("with-processes") await verify("with-processes") with dask.config.set(scheduler="threads"): - do_writes("with-threads") + await do_writes("with-threads") await verify("with-threads") diff --git a/icechunk-python/tests/test_gc.py b/icechunk-python/tests/test_gc.py index 46acf5b70..1b6a0b153 100644 --- a/icechunk-python/tests/test_gc.py +++ b/icechunk-python/tests/test_gc.py @@ -1,3 +1,4 @@ +import asyncio import time from datetime import UTC, datetime from typing import Any, cast @@ -31,9 +32,9 @@ def mk_repo(spec_version: int | None) -> tuple[str, ic.Repository]: @pytest.mark.filterwarnings("ignore:datetime.datetime.utcnow") @pytest.mark.parametrize("use_async", [True, False]) async def test_expire_and_gc(use_async: bool, any_spec_version: int | None) -> None: - prefix, repo = mk_repo(any_spec_version) + prefix, repo = await asyncio.to_thread(mk_repo, any_spec_version) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) @@ -44,24 +45,24 @@ async def test_expire_and_gc(use_async: bool, any_spec_version: int | None) -> N dtype="i4", fill_value=-1, ) - session.commit("array created") + await session.commit_async("array created") for i in range(20): - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.open_group(store=store) array = cast("zarr.core.array.Array[Any]", group["array"]) array[i] = i - session.commit(f"written coord {i}") + await session.commit_async(f"written coord {i}") old = datetime.now(UTC) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = 
zarr.open_group(store=store) array = cast("zarr.core.array.Array[Any]", group["array"]) array[999] = 0 - session.commit("written coord 999") + await session.commit_async("written coord 999") client = get_minio_client() @@ -105,7 +106,7 @@ async def test_expire_and_gc(use_async: bool, any_spec_version: int | None) -> N if use_async: expired_snapshots = await repo.expire_snapshots_async(old) else: - expired_snapshots = repo.expire_snapshots(old) + expired_snapshots = await asyncio.to_thread(repo.expire_snapshots, old) # empty array + 20 old versions assert len(expired_snapshots) == 21 @@ -135,7 +136,7 @@ def space_used() -> int: if use_async: gc_result = await repo.garbage_collect_async(old, dry_run=True) else: - gc_result = repo.garbage_collect(old, dry_run=True) + gc_result = await asyncio.to_thread(repo.garbage_collect, old, dry_run=True) space_after = space_used() assert space_before == space_after @@ -154,7 +155,7 @@ def space_used() -> int: if use_async: gc_result = await repo.garbage_collect_async(old) else: - gc_result = repo.garbage_collect(old) + gc_result = await asyncio.to_thread(repo.garbage_collect, old) space_after = space_used() @@ -204,7 +205,7 @@ def space_used() -> int: ) # we can still read the array - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") store = session.store group = zarr.open_group(store=store, mode="r") array = cast("zarr.core.array.Array[Any]", group["array"]) diff --git a/icechunk-python/tests/test_inspect.py b/icechunk-python/tests/test_inspect.py index a04ebc325..d4ea34a6d 100644 --- a/icechunk-python/tests/test_inspect.py +++ b/icechunk-python/tests/test_inspect.py @@ -1,3 +1,4 @@ +import asyncio import json import icechunk as ic @@ -23,7 +24,7 @@ async def test_inspect_snapshot_async() -> None: repo = await ic.Repository.open_async( storage=ic.local_filesystem_storage("./tests/data/split-repo-v2") ) - snap = next(repo.ancestry(branch="main")).id + snap = (await 
asyncio.to_thread(lambda: next(repo.ancestry(branch="main")))).id pretty_str = await repo.inspect_snapshot_async(snap, pretty=True) non_pretty_str = await repo.inspect_snapshot_async(snap, pretty=False) diff --git a/icechunk-python/tests/test_move.py b/icechunk-python/tests/test_move.py index a536eace5..045a79847 100644 --- a/icechunk-python/tests/test_move.py +++ b/icechunk-python/tests/test_move.py @@ -1,3 +1,5 @@ +import asyncio + import numpy.testing import pytest @@ -6,10 +8,10 @@ async def test_basic_move() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store, overwrite=True) group = root.create_group("my/old/path", overwrite=True) @@ -24,11 +26,11 @@ async def test_basic_move() -> None: "my/old/path/array/zarr.json", ] ) - session.commit("create array") + await session.commit_async("create array") - session = repo.rearrange_session("main") + session = await repo.rearrange_session_async("main") store = session.store - session.move("/my/old", "/my/new") + await session.move_async("/my/old", "/my/new") all_keys = sorted([k async for k in store.list()]) assert all_keys == sorted( [ @@ -39,16 +41,16 @@ async def test_basic_move() -> None: "my/new/path/array/zarr.json", ] ) - session.commit("directory renamed") + await session.commit_async("directory renamed") - session = repo.readonly_session("main") + session = await repo.readonly_session_async("main") store = session.store group = zarr.open_group(store=store, mode="r") array = group["my/new/path/array"] numpy.testing.assert_array_equal(array, 42) - a, b, *_ = repo.ancestry(branch="main") - diff = repo.diff(from_snapshot_id=b.id, to_snapshot_id=a.id) + a, b, *_ = await asyncio.to_thread(lambda: list(repo.ancestry(branch="main"))) + diff = await repo.diff_async(from_snapshot_id=b.id, to_snapshot_id=a.id) 
assert diff.moved_nodes == [("/my/old", "/my/new")] assert ( repr(diff) diff --git a/icechunk-python/tests/test_regressions.py b/icechunk-python/tests/test_regressions.py index 494556fcc..33027e8bd 100644 --- a/icechunk-python/tests/test_regressions.py +++ b/icechunk-python/tests/test_regressions.py @@ -53,13 +53,13 @@ async def test_issue_418(any_spec_version: int | None) -> None: } ) - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), config=config, authorize_virtual_chunk_access=credentials, spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.Group.from_store(store=store, zarr_format=3) @@ -82,9 +82,9 @@ async def test_issue_418(any_spec_version: int | None) -> None: assert (await store._store.get("time/c/0")) == b"firs" # codespell:ignore firs assert (await store._store.get("time/c/1")) == b"econ" - session.commit("Initial commit") + await session.commit_async("Initial commit") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.Group.open(store=store) @@ -104,7 +104,7 @@ async def test_issue_418(any_spec_version: int | None) -> None: assert (await store._store.get("time/c/2")) == b"thir" # commit - session.commit("Append virtual ref") + await session.commit_async("Append virtual ref") assert (await store._store.get("lon/c/0")) == b"fift" assert (await store._store.get("time/c/0")) == b"firs" # codespell:ignore firs @@ -114,11 +114,11 @@ async def test_issue_418(any_spec_version: int | None) -> None: async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: # This regression appeared during the change to manifest per array - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await 
repo.writable_session_async("main") store = session.store root = zarr.Group.from_store(store=store, zarr_format=3) @@ -127,9 +127,9 @@ async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: array1[:] = 42 assert array1[0] == 42 print("about to commit 1") - session.commit("create array 1") + await session.commit_async("create array 1") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store, zarr_format=3) # array1 = root.require_array(name="array1", shape=((2,)), chunks=((1,)), dtype="i4") @@ -137,9 +137,9 @@ async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: # assert array1[0] == 42 array2[:] = 84 print("about to commit 2") - session.commit("create array 2") + await session.commit_async("create array 2") - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") store = session.store root = zarr.Group.open(store=store, zarr_format=3) array1 = root.require_array(name="array1", shape=((2,)), chunks=((1,)), dtype="i4") @@ -149,11 +149,11 @@ async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: async def test_tag_with_open_session(any_spec_version: int | None) -> None: """This is an issue found by hypothesis""" - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store await store.set( @@ -168,8 +168,8 @@ async def test_tag_with_open_session(any_spec_version: int | None) -> None: b'{\n "shape": [\n 1\n ],\n "data_type": "bool",\n "chunk_grid": {\n "name": "regular",\n "configuration": {\n "chunk_shape": [\n 1\n ]\n }\n },\n "chunk_key_encoding": {\n "name": "default",\n "configuration": {\n "separator": "/"\n }\n },\n "fill_value": false,\n "codecs": [\n {\n 
"name": "bytes"\n }\n ],\n "attributes": {},\n "zarr_format": 3,\n "node_type": "array",\n "storage_transformers": []\n}' ), ) - session.commit("") - session = repo.writable_session("main") + await session.commit_async("") + session = await repo.writable_session_async("main") store = session.store await store.set( @@ -193,8 +193,8 @@ async def test_tag_with_open_session(any_spec_version: int | None) -> None: ), ) - session.commit("") - session = repo.writable_session("main") + await session.commit_async("") + session = await repo.writable_session_async("main") store = session.store async for k in store.list_prefix(""): diff --git a/icechunk-python/tests/test_session.py b/icechunk-python/tests/test_session.py index a4fc03ced..bf83c5d15 100644 --- a/icechunk-python/tests/test_session.py +++ b/icechunk-python/tests/test_session.py @@ -24,11 +24,11 @@ @pytest.mark.parametrize("use_async", [True, False]) async def test_session_fork(use_async: bool, any_spec_version: int | None) -> None: with tempfile.TemporaryDirectory() as tmpdir: - repo = Repository.create( + repo = await Repository.create_async( local_filesystem_storage(tmpdir), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") zarr.group(session.store) assert session.has_uncommitted_changes @@ -38,7 +38,7 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No if use_async: await session.commit_async("init") else: - session.commit("init") + await session.commit_async("init") # forking a read-only session with pytest.raises(ValueError): @@ -48,29 +48,29 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No pickle.loads(pickle.dumps(session.fork())) pickle.loads(pickle.dumps(session)) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") fork = pickle.loads(pickle.dumps(session.fork())) zarr.create_group(fork.store, path="/foo") assert not 
session.has_uncommitted_changes assert fork.has_uncommitted_changes with pytest.warns(UserWarning): with pytest.raises(IcechunkError, match="cannot commit"): - session.commit("foo") + await session.commit_async("foo") if use_async: await session.merge_async(fork) await session.commit_async("foo") else: - session.merge(fork) - session.commit("foo") + await session.merge_async(fork) + await session.commit_async("foo") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") fork1 = pickle.loads(pickle.dumps(session.fork())) fork2 = pickle.loads(pickle.dumps(session.fork())) zarr.create_group(fork1.store, path="/foo1") zarr.create_group(fork2.store, path="/foo2") with pytest.raises(TypeError, match="Cannot commit a fork"): - fork1.commit("foo") + await fork1.commit_async("foo") fork1 = pickle.loads(pickle.dumps(fork1)) fork2 = pickle.loads(pickle.dumps(fork2)) @@ -83,8 +83,8 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No await session.merge_async(fork1, fork2) await session.commit_async("all done") else: - session.merge(fork1, fork2) - session.commit("all done") + await session.merge_async(fork1, fork2) + await session.commit_async("all done") groups = set( name for name, _ in zarr.open_group(session.store, mode="r").groups() @@ -92,7 +92,7 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No assert groups == {"foo", "foo1", "foo2"} # forking a forked session may be useful - session = repo.writable_session("main") + session = await repo.writable_session_async("main") fork1 = pickle.loads(pickle.dumps(session.fork())) fork2 = pickle.loads(pickle.dumps(fork1.fork())) zarr.create_group(fork1.store, path="/foo3") @@ -102,7 +102,7 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No fork1 = pickle.loads(pickle.dumps(fork1)) fork2 = pickle.loads(pickle.dumps(fork2)) - session.merge(fork1, fork2) + await session.merge_async(fork1, fork2) 
groups = set( name for name, _ in zarr.open_group(session.store, mode="r").groups() @@ -124,13 +124,13 @@ async def test_chunk_type( container = VirtualChunkContainer("file:///foo/", store_config) config.set_virtual_chunk_container(container) - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), config=config, spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) air_temp = group.create_array("air_temp", shape=(1, 4), chunks=(1, 1), dtype="i4") @@ -154,9 +154,9 @@ async def test_chunk_type( air_temp[0, 2] = 42 assert air_temp[0, 2] == 42 - assert session.chunk_type("/air_temp", [0, 0]) == ChunkType.VIRTUAL - assert session.chunk_type("/air_temp", [0, 2]) == chunk_type - assert session.chunk_type("/air_temp", [0, 3]) == ChunkType.UNINITIALIZED + assert await session.chunk_type_async("/air_temp", [0, 0]) == ChunkType.VIRTUAL + assert await session.chunk_type_async("/air_temp", [0, 2]) == chunk_type + assert await session.chunk_type_async("/air_temp", [0, 3]) == ChunkType.UNINITIALIZED assert await session.chunk_type_async("/air_temp", [0, 0]) == ChunkType.VIRTUAL assert await session.chunk_type_async("/air_temp", [0, 2]) == chunk_type diff --git a/icechunk-python/tests/test_shift.py b/icechunk-python/tests/test_shift.py index 1f9b09bb1..017472226 100644 --- a/icechunk-python/tests/test_shift.py +++ b/icechunk-python/tests/test_shift.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import Iterable from typing import Any, cast @@ -8,18 +9,18 @@ async def test_shift_using_function() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) array = root.create_array( 
"array", shape=(50,), chunks=(2,), dtype="i4", fill_value=42 ) array[:] = np.arange(50) - session.commit("create array") + await session.commit_async("create array") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=False) array = cast("zarr.Array[Any]", root["array"]) assert array[0] == 0 @@ -29,27 +30,27 @@ def reindex(idx: Iterable[int]) -> Iterable[int] | None: idx = list(idx) return [idx[0] - 4] if idx[0] >= 4 else None - session.reindex_array("/array", reindex) + await asyncio.to_thread(lambda: session.reindex_array("/array", reindex)) # we moved 4 chunks to the left, that's 8 array elements np.testing.assert_equal(array[0:42], np.arange(8, 50)) np.testing.assert_equal(array[42:], np.arange(42, 50)) async def test_shift_using_shift_by_offset() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) array = root.create_array( "array", shape=(50,), chunks=(2,), dtype="i4", fill_value=42 ) array[:] = np.arange(50) - session.commit("create array") + await session.commit_async("create array") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=False) - session.shift_array("/array", (-4,)) + await asyncio.to_thread(lambda: session.shift_array("/array", (-4,))) array = cast("zarr.Array[Any]", root["array"]) # we moved 4 chunks to the left, that's 8 array elements np.testing.assert_equal(array[0:42], np.arange(8, 50)) @@ -57,30 +58,30 @@ async def test_shift_using_shift_by_offset() -> None: async def test_resize_and_shift_right() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = 
repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) array = root.create_array( "array", shape=(50,), chunks=(2,), dtype="i4", fill_value=42 ) array[:] = np.arange(50) - session.commit("create array") + await session.commit_async("create array") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=False) array = cast("zarr.Array[Any]", root["array"]) array.resize((100,)) assert array.shape == (100,) - session.shift_array("/array", (4,)) + await asyncio.to_thread(lambda: session.shift_array("/array", (4,))) np.testing.assert_equal(array[8:58], np.arange(50)) np.testing.assert_equal(array[0:8], np.arange(8)) assert np.all(array[58:] == 42) - session.commit("shifted") + await session.commit_async("shifted") # test still valid after commit - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") root = zarr.open_group(store=session.store, mode="r") array = cast("zarr.Array[Any]", root["array"]) assert array.shape == (100,) diff --git a/icechunk-python/tests/test_store.py b/icechunk-python/tests/test_store.py index 475ab86e6..bc6641aed 100644 --- a/icechunk-python/tests/test_store.py +++ b/icechunk-python/tests/test_store.py @@ -1,24 +1,25 @@ +import asyncio import json import numpy as np import icechunk as ic import zarr -from tests.conftest import parse_repo +from tests.conftest import parse_repo, parse_repo_async from zarr.core.buffer import cpu, default_buffer_prototype rng = np.random.default_rng(seed=12345) async def test_store_clear_metadata_list(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.writable_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.writable_session_async("main") store = session.store 
zarr.group(store=store) - session.commit("created node /") + await session.commit_async("created node /") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store await store.clear() zarr.group(store=store) @@ -26,8 +27,8 @@ async def test_store_clear_metadata_list(any_spec_version: int | None) -> None: async def test_store_clear_chunk_list(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.writable_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.writable_session_async("main") store = session.store array_kwargs = dict( @@ -35,9 +36,9 @@ async def test_store_clear_chunk_list(any_spec_version: int | None) -> None: ) group = zarr.group(store=store) group.create_array(**array_kwargs) # type: ignore[arg-type] - session.commit("created node /") + await session.commit_async("created node /") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store await store.clear() @@ -55,8 +56,8 @@ async def test_store_clear_chunk_list(any_spec_version: int | None) -> None: async def test_support_dimension_names_null(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.writable_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) @@ -78,12 +79,12 @@ def test_doesnt_support_consolidated_metadata(any_spec_version: int | None) -> N async def test_with_readonly(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.readonly_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.readonly_session_async("main") store = session.store assert store.read_only 
- session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store writer = store.with_read_only(read_only=False) assert not writer._is_open @@ -99,29 +100,37 @@ async def test_with_readonly(any_spec_version: int | None) -> None: async def test_transaction(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - cid1 = repo.lookup_branch("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + cid1 = await repo.lookup_branch_async("main") + # TODO: test metadata, rebase_with, and rebase_tries kwargs - with repo.transaction("main", message="initialize group") as store: - assert not store.read_only - root = zarr.group(store=store) - root.attrs["foo"] = "bar" - cid2 = repo.lookup_branch("main") + def _run_transaction() -> None: + with repo.transaction("main", message="initialize group") as store: + assert not store.read_only + root = zarr.group(store=store) + root.attrs["foo"] = "bar" + + await asyncio.to_thread(_run_transaction) + cid2 = await repo.lookup_branch_async("main") assert cid1 != cid2, "Transaction did not commit changes" async def test_transaction_failed_no_commit(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - cid1 = repo.lookup_branch("main") - try: + repo = await parse_repo_async("memory", "test", any_spec_version) + cid1 = await repo.lookup_branch_async("main") + + def _run_failing_transaction() -> None: with repo.transaction("main", message="initialize group") as store: assert not store.read_only root = zarr.group(store=store) root.attrs["foo"] = "bar" raise RuntimeError("Simulating an error to prevent commit") + + try: + await asyncio.to_thread(_run_failing_transaction) except RuntimeError: pass - cid2 = repo.lookup_branch("main") + cid2 = await repo.lookup_branch_async("main") assert cid1 == cid2, "Transaction committed changes despite error" diff --git 
a/icechunk-python/tests/test_timetravel.py b/icechunk-python/tests/test_timetravel.py index 0106c9a76..22db515ee 100644 --- a/icechunk-python/tests/test_timetravel.py +++ b/icechunk-python/tests/test_timetravel.py @@ -14,7 +14,11 @@ async def async_ancestry( repo: ic.Repository, **kwargs: str | None ) -> list[ic.SnapshotInfo]: - return [parent async for parent in repo.async_ancestry(**kwargs)] + return await asyncio.to_thread(lambda: list(repo.ancestry(**kwargs))) + + +async def first_ancestry(repo: ic.Repository, **kwargs: str | None) -> ic.SnapshotInfo: + return await asyncio.to_thread(lambda: next(repo.ancestry(**kwargs))) @pytest.mark.parametrize( @@ -235,41 +239,41 @@ def test_timetravel(using_flush: bool, any_spec_version: int | None) -> None: async def test_branch_reset(any_spec_version: int | None) -> None: config = ic.RepositoryConfig.default() config.inline_chunk_threshold_bytes = 1 - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), config=config, spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) group.create_group("a") - prev_snapshot_id = session.commit("group a") + prev_snapshot_id = await session.commit_async("group a") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.open_group(store=store) group.create_group("b") - last_commit = session.commit("group b") + last_commit = await session.commit_async("group b") keys = {k async for k in store.list()} assert "a/zarr.json" in keys assert "b/zarr.json" in keys with pytest.raises(ic.IcechunkError, match="branch update conflict"): - repo.reset_branch( + await repo.reset_branch_async( "main", prev_snapshot_id, from_snapshot_id="1CECHNKREP0F1RSTCMT0" ) - assert last_commit == repo.lookup_branch("main") + assert last_commit == 
await repo.lookup_branch_async("main") - repo.reset_branch("main", prev_snapshot_id, from_snapshot_id=last_commit) - assert prev_snapshot_id == repo.lookup_branch("main") + await repo.reset_branch_async("main", prev_snapshot_id, from_snapshot_id=last_commit) + assert prev_snapshot_id == await repo.lookup_branch_async("main") - session = repo.readonly_session("main") + session = await repo.readonly_session_async("main") store = session.store keys = {k async for k in store.list()} @@ -282,52 +286,52 @@ async def test_branch_reset(any_spec_version: int | None) -> None: async def test_tag_delete(any_spec_version: int | None) -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), spec_version=any_spec_version, ) - snap = repo.lookup_branch("main") - repo.create_tag("tag", snap) - repo.delete_tag("tag") + snap = await repo.lookup_branch_async("main") + await repo.create_tag_async("tag", snap) + await repo.delete_tag_async("tag") with pytest.raises(ic.IcechunkError): - repo.delete_tag("tag") + await repo.delete_tag_async("tag") with pytest.raises(ic.IcechunkError): - repo.create_tag("tag", snap) + await repo.create_tag_async("tag", snap) async def test_session_with_as_of(any_spec_version: int | None) -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store times = [] group = zarr.group(store=store, overwrite=True) - sid = session.commit("root") - times.append(next(repo.ancestry(snapshot_id=sid)).written_at) + sid = await session.commit_async("root") + times.append((await first_ancestry(repo, snapshot_id=sid)).written_at) for i in range(5): - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.open_group(store=store) 
group.create_group(f"child {i}") - sid = session.commit(f"child {i}") - times.append(next(repo.ancestry(snapshot_id=sid)).written_at) + sid = await session.commit_async(f"child {i}") + times.append((await first_ancestry(repo, snapshot_id=sid)).written_at) - ancestry = list(p for p in repo.ancestry(branch="main")) + ancestry = await async_ancestry(repo, branch="main") assert len(ancestry) == 7 # initial + root + 5 children - store = repo.readonly_session("main", as_of=times[-1]).store + store = (await repo.readonly_session_async("main", as_of=times[-1])).store group = zarr.open_group(store=store, mode="r") for i, time in enumerate(times): - store = repo.readonly_session("main", as_of=time).store + store = (await repo.readonly_session_async("main", as_of=time)).store group = zarr.open_group(store=store, mode="r") expected_children = {f"child {j}" for j in range(i)} actual_children = {g[0] for g in group.members()} @@ -335,17 +339,17 @@ async def test_session_with_as_of(any_spec_version: int | None) -> None: async def test_default_commit_metadata(any_spec_version: int | None) -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), spec_version=any_spec_version, ) - repo.set_default_commit_metadata({"user": "test"}) - session = repo.writable_session("main") + await asyncio.to_thread(lambda: repo.set_default_commit_metadata({"user": "test"})) + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) root.create_group("child") - sid = session.commit("root") - snap = next(repo.ancestry(snapshot_id=sid)) + sid = await session.commit_async("root") + snap = await first_ancestry(repo, snapshot_id=sid) assert snap.metadata == {"user": "test"} @@ -391,34 +395,34 @@ def test_set_metadata() -> None: async def test_set_metadata_async() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - assert repo.metadata 
== {} + assert await repo.get_metadata_async() == {} await repo.set_metadata_async({"user": "test"}) assert await repo.get_metadata_async() == {"user": "test"} async def test_update_metadata() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - assert repo.metadata == {} + assert await repo.get_metadata_async() == {} - repo.update_metadata({"user": "test"}) - assert repo.get_metadata() == {"user": "test"} - repo.update_metadata({"foo": 42}) - assert repo.get_metadata() == {"user": "test", "foo": 42} - repo.update_metadata({"foo": 43}) - assert repo.get_metadata() == {"user": "test", "foo": 43} + await repo.update_metadata_async({"user": "test"}) + assert await repo.get_metadata_async() == {"user": "test"} + await repo.update_metadata_async({"foo": 42}) + assert await repo.get_metadata_async() == {"user": "test", "foo": 42} + await repo.update_metadata_async({"foo": 43}) + assert await repo.get_metadata_async() == {"user": "test", "foo": 43} async def test_update_metadata_async() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - assert repo.metadata == {} + assert await repo.get_metadata_async() == {} await repo.update_metadata_async({"user": "test"}) assert await repo.get_metadata_async() == {"user": "test"} @@ -450,7 +454,7 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) air_temp[:, :] = 42 assert air_temp[200, 6] == 42 - status = session.status() + status = await asyncio.to_thread(lambda: session.status()) assert status.new_groups == {"/"} assert status.new_arrays == {"/air_temp"} assert list(status.updated_chunks.keys()) == ["/air_temp"] @@ -542,9 +546,7 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) air_temp = cast("zarr.core.array.Array[Any]", group["air_temp"]) assert air_temp[200, 6] == 90 - parents = [ - parent async for parent in 
repo.async_ancestry(snapshot_id=feature_snapshot_id) - ] + parents = await async_ancestry(repo, snapshot_id=feature_snapshot_id) assert [snap.message for snap in parents] == [ "commit 3", "commit 2", @@ -552,13 +554,16 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) "Repository initialized", ] assert parents[-1].id == "1CECHNKREP0F1RSTCMT0" - assert [len(repo.list_manifest_files(snap.id)) for snap in parents] == [1, 1, 1, 0] + assert [len(await repo.list_manifest_files_async(snap.id)) for snap in parents] == [ + 1, + 1, + 1, + 0, + ] assert sorted(parents, key=lambda p: p.written_at) == list(reversed(parents)) assert len(set([snap.id for snap in parents])) == 4 - assert [parent async for parent in repo.async_ancestry(tag="v1.0")] == parents - assert [ - parent async for parent in repo.async_ancestry(branch="feature-not-dead") - ] == parents + assert await async_ancestry(repo, tag="v1.0") == parents + assert await async_ancestry(repo, branch="feature-not-dead") == parents diff = await repo.diff_async(to_tag="v1.0", from_snapshot_id=parents[-1].id) assert diff.new_groups == {"/"} @@ -602,7 +607,7 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) "air_temp", shape=(1000, 1000), chunks=(100, 100), dtype="i4", overwrite=True ) assert ( - repr(session.status()) + repr(await asyncio.to_thread(lambda: session.status())) == """\ Arrays created: /air_temp @@ -622,11 +627,11 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) tag_snapshot_id = await repo.lookup_tag_async("v1.0") assert tag_snapshot_id == feature_snapshot_id - actual = next(iter([parent async for parent in repo.async_ancestry(tag="v1.0")])) + actual = (await async_ancestry(repo, tag="v1.0"))[0] assert actual == await repo.lookup_snapshot_async(actual.id) if any_spec_version is None or any_spec_version > 1: - ops = [type(op) for op in repo.ops_log()] + ops = await asyncio.to_thread(lambda: [type(op) for op in 
repo.ops_log()]) flush_or_commit = ( [ic.NewCommitUpdate] if not using_flush @@ -711,7 +716,7 @@ async def test_session_with_as_of_async(any_spec_version: int | None) -> None: times = [] group = zarr.group(store=session.store, overwrite=True) sid = await session.commit_async("root") - to_append = await anext(repo.async_ancestry(snapshot_id=sid)) + to_append = await first_ancestry(repo, snapshot_id=sid) times.append(to_append.written_at) for i in range(5): @@ -719,10 +724,10 @@ async def test_session_with_as_of_async(any_spec_version: int | None) -> None: group = zarr.open_group(store=session.store) group.create_group(f"child {i}") sid = await session.commit_async(f"child {i}") - to_append = await anext(repo.async_ancestry(snapshot_id=sid)) + to_append = await first_ancestry(repo, snapshot_id=sid) times.append(to_append.written_at) - ancestry = [p async for p in repo.async_ancestry(branch="main")] + ancestry = await async_ancestry(repo, branch="main") assert len(ancestry) == 7 # initial + root + 5 children store = (await repo.readonly_session_async("main", as_of=times[-1])).store @@ -926,7 +931,7 @@ async def test_repository_lifecycle_async(any_spec_version: int | None) -> None: new_config_sync = ic.RepositoryConfig.default() new_config_sync.inline_chunk_threshold_bytes = 2048 - reopened_repo_sync = repo.reopen(config=new_config_sync) + reopened_repo_sync = await repo.reopen_async(config=new_config_sync) assert reopened_repo_sync.config.inline_chunk_threshold_bytes == 2048 # Test open_or_create_async with new storage (should create) @@ -965,7 +970,7 @@ async def test_rewrite_manifests_async(any_spec_version: int | None) -> None: second_commit = await session.commit_async("more data") # Get initial ancestry to verify commits exist - initial_ancestry = [snap async for snap in repo.async_ancestry(branch="main")] + initial_ancestry = await async_ancestry(repo, branch="main") assert len(initial_ancestry) >= 3 # initial + first_commit + second_commit # Test 
rewrite_manifests_async @@ -978,7 +983,7 @@ async def test_rewrite_manifests_async(any_spec_version: int | None) -> None: assert rewrite_commit != second_commit # Verify ancestry after rewrite - new_ancestry = [snap async for snap in repo.async_ancestry(branch="main")] + new_ancestry = await async_ancestry(repo, branch="main") extra_commits = 1 if any_spec_version == 1 else 0 # we do amend in IC2+ assert len(new_ancestry) == len(initial_ancestry) + extra_commits assert new_ancestry[0].message == "rewritten manifests" @@ -1089,7 +1094,7 @@ async def test_long_ops_log(spec_version: int | None) -> None: snap = await repo.lookup_branch_async("main") for i in range(1, NUM_BRANCHES + 1): await repo.create_branch_async(str(i), snap) - updates = [update async for update in repo.ops_log_async()] + updates = await asyncio.to_thread(lambda: list(repo.ops_log())) assert len(updates) == NUM_BRANCHES + 1 t = type(updates[0]) diff --git a/icechunk-python/tests/test_virtual_ref.py b/icechunk-python/tests/test_virtual_ref.py index 2523324d8..9aa2f4a55 100644 --- a/icechunk-python/tests/test_virtual_ref.py +++ b/icechunk-python/tests/test_virtual_ref.py @@ -60,13 +60,13 @@ async def test_write_minio_virtual_refs( ) # Open the store - repo = Repository.open_or_create( + repo = await Repository.open_or_create_async( storage=in_memory_storage(), config=config, authorize_virtual_chunk_access=credentials, create_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store array = zarr.create_array( @@ -210,7 +210,7 @@ async def test_write_minio_virtual_refs( # TODO: we should include the key and other info in the exception await store.get("c/0/0/2", prototype=buffer_prototype) - all_locations = set(session.all_virtual_chunk_locations()) + all_locations = set(await session.all_virtual_chunk_locations_async()) assert f"s3://testbucket/{prefix}/non-existing" in all_locations assert 
f"s3://testbucket/{prefix}/chunk-1" in all_locations assert f"s3://testbucket/{prefix}/chunk-2" in all_locations @@ -226,7 +226,7 @@ async def test_write_minio_virtual_refs( with pytest.raises(IcechunkError, match="chunk has changed"): await store.get("c/3/0/1", prototype=buffer_prototype) - _snapshot_id = session.commit("Add virtual refs") + _snapshot_id = await session.commit_async("Add virtual refs") @pytest.mark.parametrize( @@ -257,13 +257,13 @@ async def test_public_virtual_refs( container = VirtualChunkContainer(url_prefix, store_config) config.set_virtual_chunk_container(container) - repo = Repository.open_or_create( + repo = await Repository.open_or_create_async( storage=local_filesystem_storage(f"{tmpdir}/virtual-{container_type}"), config=config, authorize_virtual_chunk_access={url_prefix: None}, create_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.Group.from_store(store=store, zarr_format=3) diff --git a/icechunk-python/tests/test_zarr/test_store/test_core.py b/icechunk-python/tests/test_zarr/test_store/test_core.py index b6b8ace07..29bac7c90 100644 --- a/icechunk-python/tests/test_zarr/test_store/test_core.py +++ b/icechunk-python/tests/test_zarr/test_store/test_core.py @@ -1,16 +1,16 @@ from icechunk import IcechunkStore -from tests.conftest import parse_repo +from tests.conftest import parse_repo_async from zarr.storage._common import make_store_path async def test_make_store_path(any_spec_version: int | None) -> None: # Memory store - repo = parse_repo( + repo = await parse_repo_async( "memory", path="", spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store store_path = await make_store_path(store) assert isinstance(store_path.store, IcechunkStore) diff --git a/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py 
b/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py index eb27153e1..c5b2e217e 100644 --- a/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py +++ b/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py @@ -55,21 +55,21 @@ def store_kwargs(self, tmpdir: Path) -> dict[str, Any]: @pytest.fixture async def store(self, store_kwargs: dict[str, Any]) -> IcechunkStore: read_only = store_kwargs.pop("read_only") - repo = Repository.open_or_create(**store_kwargs) + repo = await Repository.open_or_create_async(**store_kwargs) if read_only: - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") else: - session = repo.writable_session("main") + session = await repo.writable_session_async("main") return session.store @pytest.fixture async def store_not_open(self, store_kwargs: dict[str, Any]) -> IcechunkStore: read_only = store_kwargs.pop("read_only") - repo = Repository.open_or_create(**store_kwargs) + repo = await Repository.open_or_create_async(**store_kwargs) if read_only: - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") else: - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store store._is_open = False @@ -98,11 +98,11 @@ def test_store_repr(self, store: IcechunkStore) -> None: async def test_store_open_read_only( self, store: IcechunkStore, store_kwargs: dict[str, Any], read_only: bool ) -> None: - repo = Repository.open(**store_kwargs) + repo = await Repository.open_async(**store_kwargs) if read_only: - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") else: - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store assert store._is_open assert store.read_only == read_only @@ -111,8 +111,8 @@ async def test_read_only_store_raises( # type: 
ignore[override] self, store: IcechunkStore, store_kwargs: dict[str, Any] ) -> None: kwargs = {**store_kwargs} - repo = Repository.open(**kwargs) - session = repo.readonly_session(branch="main") + repo = await Repository.open_async(**kwargs) + session = await repo.readonly_session_async(branch="main") store = session.store assert store.read_only