From b9d84214e59c77f350f0ce936356e50b227061b1 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Wed, 11 Feb 2026 16:46:30 -0500 Subject: [PATCH 01/14] Support async refreshable credential callbacks in Python bindings --- .../python/icechunk/credentials.py | 51 ++++++-- icechunk-python/python/icechunk/storage.py | 26 ++-- icechunk-python/src/config.rs | 118 +++++++++++++++--- icechunk-python/tests/test_credentials.py | 102 +++++++++++++++ 4 files changed, 260 insertions(+), 37 deletions(-) diff --git a/icechunk-python/python/icechunk/credentials.py b/icechunk-python/python/icechunk/credentials.py index e3a603095..2259b788e 100644 --- a/icechunk-python/python/icechunk/credentials.py +++ b/icechunk-python/python/icechunk/credentials.py @@ -1,7 +1,9 @@ +import asyncio +import inspect import pickle -from collections.abc import Callable, Mapping +from collections.abc import Awaitable, Callable, Mapping from datetime import datetime -from typing import cast +from typing import TypeVar, cast from icechunk._icechunk_python import ( AzureCredentials, @@ -46,16 +48,43 @@ AnyCredential = Credentials.S3 | Credentials.Gcs | Credentials.Azure +S3CredentialProvider = Callable[ + [], S3StaticCredentials | Awaitable[S3StaticCredentials] +] +GcsCredentialProvider = Callable[ + [], GcsBearerCredential | Awaitable[GcsBearerCredential] +] +CredentialType = TypeVar("CredentialType") + + +def _resolve_callback_once( + get_credentials: Callable[[], CredentialType | Awaitable[CredentialType]], +) -> CredentialType: + value = get_credentials() + if inspect.isawaitable(value): + try: + return asyncio.run(cast(Awaitable[CredentialType], value)) + except RuntimeError as err: + if "asyncio.run() cannot be called from a running event loop" in str(err): + raise ValueError( + "scatter_initial_credentials=True cannot eagerly evaluate async " + "credential callbacks while an event loop is running. " + "Set scatter_initial_credentials=False in async contexts." 
+ ) from err + raise + + return cast(CredentialType, value) + def s3_refreshable_credentials( - get_credentials: Callable[[], S3StaticCredentials], + get_credentials: S3CredentialProvider, scatter_initial_credentials: bool = False, ) -> S3Credentials.Refreshable: """Create refreshable credentials for S3 and S3 compatible object stores. Parameters ---------- - get_credentials: Callable[[], S3StaticCredentials] + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] Use this function to get and refresh the credentials. The function must be pickable. scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -64,7 +93,7 @@ def s3_refreshable_credentials( set of credentials has expired, the cached value is no longer used. Notice that credentials obtained are stored, and they can be sent over the network if you pickle the session/repo. """ - current = get_credentials() if scatter_initial_credentials else None + current = _resolve_callback_once(get_credentials) if scatter_initial_credentials else None return S3Credentials.Refreshable(pickle.dumps(get_credentials), current) @@ -116,7 +145,7 @@ def s3_credentials( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: S3CredentialProvider | None = None, scatter_initial_credentials: bool = False, ) -> AnyS3Credential: """Create credentials for S3 and S3 compatible object stores. 
@@ -137,7 +166,7 @@ def s3_credentials( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -218,14 +247,14 @@ def gcs_static_credentials( def gcs_refreshable_credentials( - get_credentials: Callable[[], GcsBearerCredential], + get_credentials: GcsCredentialProvider, scatter_initial_credentials: bool = False, ) -> GcsCredentials.Refreshable: """Create refreshable credentials for Google Cloud Storage object store. Parameters ---------- - get_credentials: Callable[[], S3StaticCredentials] + get_credentials: Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] Use this function to get and refresh the credentials. The function must be pickable. scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -235,7 +264,7 @@ def gcs_refreshable_credentials( obtained are stored, and they can be sent over the network if you pickle the session/repo. 
""" - current = get_credentials() if scatter_initial_credentials else None + current = _resolve_callback_once(get_credentials) if scatter_initial_credentials else None return GcsCredentials.Refreshable(pickle.dumps(get_credentials), current) @@ -257,7 +286,7 @@ def gcs_credentials( bearer_token: str | None = None, from_env: bool | None = None, anonymous: bool | None = None, - get_credentials: Callable[[], GcsBearerCredential] | None = None, + get_credentials: GcsCredentialProvider | None = None, scatter_initial_credentials: bool = False, ) -> AnyGcsCredential: """Create credentials Google Cloud Storage object store. diff --git a/icechunk-python/python/icechunk/storage.py b/icechunk-python/python/icechunk/storage.py index dbe0007e1..d66bc1c8a 100644 --- a/icechunk-python/python/icechunk/storage.py +++ b/icechunk-python/python/icechunk/storage.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Awaitable, Callable from datetime import datetime from icechunk._icechunk_python import ( @@ -125,7 +125,9 @@ def s3_storage( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: ( + Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None + ) = None, scatter_initial_credentials: bool = False, force_path_style: bool = False, network_stream_timeout_seconds: int = 60, @@ -157,7 +159,7 @@ def s3_storage( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. 
This is useful if the @@ -254,7 +256,9 @@ def tigris_storage( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: ( + Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None + ) = None, scatter_initial_credentials: bool = False, network_stream_timeout_seconds: int = 60, ) -> Storage: @@ -288,7 +292,7 @@ def tigris_storage( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. This is useful if the @@ -340,7 +344,9 @@ def r2_storage( expires_after: datetime | None = None, anonymous: bool | None = None, from_env: bool | None = None, - get_credentials: Callable[[], S3StaticCredentials] | None = None, + get_credentials: ( + Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None + ) = None, scatter_initial_credentials: bool = False, network_stream_timeout_seconds: int = 60, ) -> Storage: @@ -374,7 +380,7 @@ def r2_storage( If set to True requests to the object store will not be signed from_env: bool | None Fetch credentials from the operative system environment - get_credentials: Callable[[], S3StaticCredentials] | None + get_credentials: Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. 
This is useful if the @@ -436,7 +442,9 @@ def gcs_storage( anonymous: bool | None = None, from_env: bool | None = None, config: dict[str, str] | None = None, - get_credentials: Callable[[], GcsBearerCredential] | None = None, + get_credentials: ( + Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] | None + ) = None, scatter_initial_credentials: bool = False, ) -> Storage: """Create a Storage instance that saves data in Google Cloud Storage object store. @@ -461,7 +469,7 @@ def gcs_storage( Fetch credentials from the operative system environment config: dict[str, str] | None A dictionary of options for the Google Cloud Storage object store. See https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html#variants for a list of possible configuration keys. - get_credentials: Callable[[], GcsBearerCredential] | None + get_credentials: Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] | None Use this function to get and refresh object store credentials scatter_initial_credentials: bool, optional Immediately call and store the value returned by get_credentials. 
This is useful if the diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 68e80a0c1..b7bedc02f 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -29,7 +29,7 @@ use icechunk::{ }; use pyo3::{ Bound, FromPyObject, Py, PyErr, PyResult, Python, pyclass, pymethods, - types::{PyAnyMethods, PyModule, PyType}, + types::{PyAny, PyAnyMethods, PyModule, PyType}, }; use crate::errors::PyIcechunkStoreError; @@ -141,25 +141,49 @@ pub(crate) fn datetime_repr(d: &DateTime) -> String { struct PythonCredentialsFetcher { pub pickled_function: Vec, pub initial: Option, + #[serde(skip, default)] + pub task_locals: Option, } impl PythonCredentialsFetcher { fn new(pickled_function: Vec) -> Self { - PythonCredentialsFetcher { pickled_function, initial: None } + PythonCredentialsFetcher { + pickled_function, + initial: None, + task_locals: current_task_locals(), + } } fn new_with_initial(pickled_function: Vec, current: C) -> Self where C: Into, { - PythonCredentialsFetcher { pickled_function, initial: Some(current.into()) } + PythonCredentialsFetcher { + pickled_function, + initial: Some(current.into()), + task_locals: current_task_locals(), + } } } +fn current_task_locals() -> Option { + Python::attach(|py| pyo3_async_runtimes::tokio::get_current_locals(py).ok()) +} + +enum PythonCallbackResult { + Value(PyCred), + Awaitable(Py), +} + +fn is_awaitable(py: Python<'_>, value: &Bound<'_, PyAny>) -> Result { + let inspect = PyModule::import(py, "inspect")?; + inspect.getattr("isawaitable")?.call1((value,))?.extract() +} + fn call_pickled( py: Python<'_>, - pickled_function: Vec, -) -> Result + pickled_function: &[u8], +) -> Result, PyErr> where PyCred: for<'a, 'py> FromPyObject<'a, 'py>, for<'a, 'py> >::Error: Into, @@ -167,8 +191,44 @@ where let pickle_module = PyModule::import(py, "pickle")?; let loads_function = pickle_module.getattr("loads")?; let fetcher = loads_function.call1((pickled_function,))?; - let creds: PyCred = 
fetcher.call0()?.extract().map_err(Into::into)?; - Ok(creds) + let value = fetcher.call0()?; + let value_for_extract = value.clone(); + let extract_err = match value_for_extract.extract::() { + Ok(creds) => return Ok(PythonCallbackResult::Value(creds)), + Err(extract_err) => extract_err, + }; + + if is_awaitable(py, &value)? { + Ok(PythonCallbackResult::Awaitable(value.unbind())) + } else { + Err(extract_err.into()) + } +} + +async fn await_python_callback( + awaitable: Py, + task_locals: Option, +) -> Result +where + PyCred: for<'a, 'py> FromPyObject<'a, 'py>, + for<'a, 'py> >::Error: Into, +{ + if let Some(task_locals) = task_locals { + let fut = Python::attach(|py| { + pyo3_async_runtimes::into_future_with_locals( + &task_locals, + awaitable.bind(py).clone(), + ) + })?; + let value = fut.await?; + return Python::attach(|py| value.bind(py).extract().map_err(Into::into)); + } + + Python::attach(|py| { + let asyncio = PyModule::import(py, "asyncio")?; + let value = asyncio.getattr("run")?.call1((awaitable.bind(py),))?; + value.extract().map_err(Into::into) + }) } #[async_trait] @@ -190,11 +250,23 @@ impl S3CredentialsFetcher for PythonCredentialsFetcher { _ => {} } } - Python::attach(|py| { - call_pickled::(py, self.pickled_function.clone()) - .map(|c| c.into()) - }) - .map_err(|e: PyErr| e.to_string()) + let callback_result = Python::attach(|py| { + call_pickled::(py, &self.pickled_function) + }); + + let creds = match callback_result { + Ok(PythonCallbackResult::Value(creds)) => Ok(creds), + Ok(PythonCallbackResult::Awaitable(awaitable)) => { + await_python_callback::( + awaitable, + self.task_locals.clone(), + ) + .await + } + Err(err) => Err(err), + }; + + creds.map(Into::into).map_err(|e: PyErr| e.to_string()) } } @@ -217,11 +289,23 @@ impl GcsCredentialsFetcher for PythonCredentialsFetcher { _ => {} } } - Python::attach(|py| { - call_pickled::(py, self.pickled_function.clone()) - .map(|c| c.into()) - }) - .map_err(|e: PyErr| e.to_string()) + let 
callback_result = Python::attach(|py| { + call_pickled::(py, &self.pickled_function) + }); + + let creds = match callback_result { + Ok(PythonCallbackResult::Value(creds)) => Ok(creds), + Ok(PythonCallbackResult::Awaitable(awaitable)) => { + await_python_callback::( + awaitable, + self.task_locals.clone(), + ) + .await + } + Err(err) => Err(err), + }; + + creds.map(Into::into).map_err(|e: PyErr| e.to_string()) } } diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index 98e8460a9..bdc71555b 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -1,3 +1,5 @@ +import asyncio +import contextvars import pickle import time from datetime import UTC, datetime @@ -64,6 +66,31 @@ def returns_something_else() -> int: return 42 +ASYNC_CREDENTIALS_CONTEXT: contextvars.ContextVar[str | None] = ( + contextvars.ContextVar("ASYNC_CREDENTIALS_CONTEXT", default=None) +) + + +class AsyncGoodCredentials: + def __init__(self, path: Path, expected_context: str | None = None) -> None: + self.path = path + self.expected_context = expected_context + + async def __call__(self) -> S3StaticCredentials: + try: + calls = self.path.read_text() + except Exception: + calls = "" + self.path.write_text(calls + ".") + + if self.expected_context is not None: + if ASYNC_CREDENTIALS_CONTEXT.get() != self.expected_context: + raise IcechunkError("missing callback context") + + await asyncio.sleep(0) + return S3StaticCredentials(access_key_id="minio123", secret_access_key="minio123") + + @pytest.mark.parametrize( "scatter_initial_credentials", [False, True], @@ -222,3 +249,78 @@ def test_s3_refreshable_credentials_pickle_with_optimization( called_only_once = path.read_text() == "." 
assert called_only_once == scatter_initial_credentials + + +def test_async_refreshable_credentials_with_sync_repository_api( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_sync_api-" + str(int(time.time() * 1000)) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + + calls_path = tmp_path / "async_sync_calls.txt" + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials(calls_path), + ) + + repo = Repository.open(callback_storage) + assert "main" in repo.list_branches() + assert calls_path.read_text() != "" + + +@pytest.mark.asyncio +async def test_async_refreshable_credentials_with_async_repository_api( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_async_api-" + str(int(time.time() * 1000)) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + + calls_path = tmp_path / "async_async_calls.txt" + context_value = "expected-refresh-context" + reset_token = ASYNC_CREDENTIALS_CONTEXT.set(context_value) + try: + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials( + calls_path, expected_context=context_value + ), + ) + + repo = await 
Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + finally: + ASYNC_CREDENTIALS_CONTEXT.reset(reset_token) + + assert calls_path.read_text() != "" From fa1b31309e41c1c6dc3cdd9306ad3a0e7a49b6f0 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 12 Feb 2026 11:50:29 -0500 Subject: [PATCH 02/14] More handling for async event loops --- .../python/icechunk/credentials.py | 23 ++-- icechunk-python/python/icechunk/repository.py | 13 +++ icechunk-python/src/config.rs | 2 +- icechunk-python/tests/test_credentials.py | 109 +++++++++++++++++- 4 files changed, 134 insertions(+), 13 deletions(-) diff --git a/icechunk-python/python/icechunk/credentials.py b/icechunk-python/python/icechunk/credentials.py index 2259b788e..1a4758d19 100644 --- a/icechunk-python/python/icechunk/credentials.py +++ b/icechunk-python/python/icechunk/credentials.py @@ -1,7 +1,7 @@ import asyncio import inspect import pickle -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Awaitable, Callable, Coroutine, Mapping from datetime import datetime from typing import TypeVar, cast @@ -48,12 +48,8 @@ AnyCredential = Credentials.S3 | Credentials.Gcs | Credentials.Azure -S3CredentialProvider = Callable[ - [], S3StaticCredentials | Awaitable[S3StaticCredentials] -] -GcsCredentialProvider = Callable[ - [], GcsBearerCredential | Awaitable[GcsBearerCredential] -] +S3CredentialProvider = Callable[[], S3StaticCredentials | Awaitable[S3StaticCredentials]] +GcsCredentialProvider = Callable[[], GcsBearerCredential | Awaitable[GcsBearerCredential]] CredentialType = TypeVar("CredentialType") @@ -62,10 +58,13 @@ def _resolve_callback_once( ) -> CredentialType: value = get_credentials() if inspect.isawaitable(value): + awaitable = cast(Awaitable[CredentialType], value) try: - return asyncio.run(cast(Awaitable[CredentialType], value)) + return asyncio.run(awaitable) except RuntimeError as err: if "asyncio.run() cannot be called from a 
running event loop" in str(err): + if inspect.iscoroutine(awaitable): + cast(Coroutine[object, object, CredentialType], awaitable).close() raise ValueError( "scatter_initial_credentials=True cannot eagerly evaluate async " "credential callbacks while an event loop is running. " @@ -93,7 +92,9 @@ def s3_refreshable_credentials( set of credentials has expired, the cached value is no longer used. Notice that credentials obtained are stored, and they can be sent over the network if you pickle the session/repo. """ - current = _resolve_callback_once(get_credentials) if scatter_initial_credentials else None + current = ( + _resolve_callback_once(get_credentials) if scatter_initial_credentials else None + ) return S3Credentials.Refreshable(pickle.dumps(get_credentials), current) @@ -264,7 +265,9 @@ def gcs_refreshable_credentials( obtained are stored, and they can be sent over the network if you pickle the session/repo. """ - current = _resolve_callback_once(get_credentials) if scatter_initial_credentials else None + current = ( + _resolve_callback_once(get_credentials) if scatter_initial_credentials else None + ) return GcsCredentials.Refreshable(pickle.dumps(get_credentials), current) diff --git a/icechunk-python/python/icechunk/repository.py b/icechunk-python/python/icechunk/repository.py index 6ef06c074..1f2e68812 100644 --- a/icechunk-python/python/icechunk/repository.py +++ b/icechunk-python/python/icechunk/repository.py @@ -1,3 +1,4 @@ +import asyncio import datetime import warnings from collections.abc import AsyncIterator, Iterator @@ -21,6 +22,17 @@ from icechunk.store import IcechunkStore +def _raise_if_running_loop(sync_method: str, async_method: str) -> None: + try: + asyncio.get_running_loop() + except RuntimeError: + return + raise RuntimeError( + f"`Repository.{sync_method}` cannot be called while an asyncio event loop is " + f"running. Use `await Repository.{async_method}` instead." 
+ ) + + class Repository: """An Icechunk repository.""" @@ -793,6 +805,7 @@ def list_branches(self) -> set[str]: set[str] A set of branch names. """ + _raise_if_running_loop("list_branches", "list_branches_async") return self._repository.list_branches() async def list_branches_async(self) -> set[str]: diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index b7bedc02f..2be39de9e 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -213,7 +213,7 @@ where PyCred: for<'a, 'py> FromPyObject<'a, 'py>, for<'a, 'py> >::Error: Into, { - if let Some(task_locals) = task_locals { + if let Some(task_locals) = current_task_locals().or(task_locals) { let fut = Python::attach(|py| { pyo3_async_runtimes::into_future_with_locals( &task_locals, diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index bdc71555b..dbcc4bae7 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -66,8 +66,8 @@ def returns_something_else() -> int: return 42 -ASYNC_CREDENTIALS_CONTEXT: contextvars.ContextVar[str | None] = ( - contextvars.ContextVar("ASYNC_CREDENTIALS_CONTEXT", default=None) +ASYNC_CREDENTIALS_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "ASYNC_CREDENTIALS_CONTEXT", default=None ) @@ -284,6 +284,111 @@ def test_async_refreshable_credentials_with_sync_repository_api( assert calls_path.read_text() != "" +def test_async_refreshable_credentials_constructed_sync_used_async( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_sync_construct_async_use-" + str( + int(time.time() * 1000) + ) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, 
spec_version=any_spec_version) + + calls_path = tmp_path / "async_sync_construct_calls.txt" + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials(calls_path), + ) + + async def use_async_repository_api() -> None: + repo = await Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + + asyncio.run(use_async_repository_api()) + assert calls_path.read_text() != "" + + +def test_async_refreshable_credentials_repo_reused_across_event_loops( + tmp_path: Path, any_spec_version: int | None +) -> None: + prefix = "test_async_refreshable_new_loop-" + str(int(time.time() * 1000)) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + + calls_path = tmp_path / "async_new_loop_calls.txt" + + async def create_and_use_repo_once() -> Repository: + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncGoodCredentials(calls_path), + ) + repo = await Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + return repo + + repo = asyncio.run(create_and_use_repo_once()) + + async def use_repo_on_different_loop() -> None: + assert "main" in await repo.list_branches_async() + + asyncio.run(use_repo_on_different_loop()) + assert calls_path.read_text() != "" + + +@pytest.mark.asyncio +async def test_sync_list_branches_in_async_context_errors( + any_spec_version: int | None, +) -> None: + prefix = "test_sync_list_branches_in_async_context_errors-" + 
str( + int(time.time() * 1000) + ) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + + repo = await Repository.open_async(create_storage) + with pytest.raises(RuntimeError, match="list_branches_async"): + repo.list_branches() + + @pytest.mark.asyncio async def test_async_refreshable_credentials_with_async_repository_api( tmp_path: Path, any_spec_version: int | None From 64e8dd8f5df22bd6cf4140459a0468a6622ce173 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 12 Feb 2026 11:58:41 -0500 Subject: [PATCH 03/14] mypy --- .../python/icechunk/credentials.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/icechunk-python/python/icechunk/credentials.py b/icechunk-python/python/icechunk/credentials.py index 1a4758d19..7cd31fe9f 100644 --- a/icechunk-python/python/icechunk/credentials.py +++ b/icechunk-python/python/icechunk/credentials.py @@ -3,7 +3,7 @@ import pickle from collections.abc import Awaitable, Callable, Coroutine, Mapping from datetime import datetime -from typing import TypeVar, cast +from typing import Any, TypeVar, cast from icechunk._icechunk_python import ( AzureCredentials, @@ -58,13 +58,21 @@ def _resolve_callback_once( ) -> CredentialType: value = get_credentials() if inspect.isawaitable(value): - awaitable = cast(Awaitable[CredentialType], value) + if inspect.iscoroutine(value): + coroutine = cast(Coroutine[Any, Any, CredentialType], value) + else: + awaitable = cast(Awaitable[CredentialType], value) + + async def _as_coroutine() -> CredentialType: + return await awaitable + + coroutine = _as_coroutine() + try: - return asyncio.run(awaitable) + return asyncio.run(coroutine) except RuntimeError as err: if "asyncio.run() cannot be called from 
a running event loop" in str(err): - if inspect.iscoroutine(awaitable): - cast(Coroutine[object, object, CredentialType], awaitable).close() + coroutine.close() raise ValueError( "scatter_initial_credentials=True cannot eagerly evaluate async " "credential callbacks while an event loop is running. " @@ -72,7 +80,7 @@ def _resolve_callback_once( ) from err raise - return cast(CredentialType, value) + return value def s3_refreshable_credentials( From ace03cee30d80d2bca23dcee61f8e7a8d6d2f94c Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 12 Feb 2026 12:07:06 -0500 Subject: [PATCH 04/14] More protection for async usage --- icechunk-python/python/icechunk/repository.py | 44 ++++++++++++++++++- icechunk-python/tests/test_credentials.py | 41 ++++++++++++++++- 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/icechunk-python/python/icechunk/repository.py b/icechunk-python/python/icechunk/repository.py index 1f2e68812..6d59a4148 100644 --- a/icechunk-python/python/icechunk/repository.py +++ b/icechunk-python/python/icechunk/repository.py @@ -22,14 +22,21 @@ from icechunk.store import IcechunkStore -def _raise_if_running_loop(sync_method: str, async_method: str) -> None: +def _raise_if_running_loop( + sync_method: str, async_method: str, *, await_async: bool = True +) -> None: try: asyncio.get_running_loop() except RuntimeError: return + guidance = ( + f"Use `await Repository.{async_method}` instead." + if await_async + else f"Use `Repository.{async_method}` instead." + ) raise RuntimeError( f"`Repository.{sync_method}` cannot be called while an asyncio event loop is " - f"running. Use `await Repository.{async_method}` instead." + f"running. {guidance}" ) @@ -79,6 +86,7 @@ def create( Self An instance of the Repository class. """ + _raise_if_running_loop("create", "create_async") return cls( PyRepository.create( storage, @@ -171,6 +179,7 @@ def open( Self An instance of the Repository class. 
""" + _raise_if_running_loop("open", "open_async") return cls( PyRepository.open( storage, @@ -266,6 +275,7 @@ def open_or_create( Self An instance of the Repository class. """ + _raise_if_running_loop("open_or_create", "open_or_create_async") return cls( PyRepository.open_or_create( storage, @@ -341,6 +351,7 @@ def exists(storage: Storage) -> bool: bool True if the repository exists, False otherwise. """ + _raise_if_running_loop("exists", "exists_async") return PyRepository.exists(storage) @staticmethod @@ -379,6 +390,7 @@ def fetch_spec_version(storage: Storage) -> int | None: The spec version of the repository if it exists, None if no repository exists at the given location. """ + _raise_if_running_loop("fetch_spec_version", "fetch_spec_version_async") return PyRepository.fetch_spec_version(storage) @staticmethod @@ -427,6 +439,7 @@ def fetch_config(storage: Storage) -> RepositoryConfig | None: RepositoryConfig | None The repository configuration if it exists, None otherwise. """ + _raise_if_running_loop("fetch_config", "fetch_config_async") return PyRepository.fetch_config(storage) @staticmethod @@ -454,6 +467,7 @@ def save_config(self) -> None: ------- None """ + _raise_if_running_loop("save_config", "save_config_async") return self._repository.save_config() async def save_config_async(self) -> None: @@ -522,6 +536,7 @@ def reopen( Self A new Repository instance with the updated configuration. """ + _raise_if_running_loop("reopen", "reopen_async") return self.__class__( self._repository.reopen( config=config, @@ -595,6 +610,7 @@ def get_metadata(self) -> dict[str, Any]: dict[str, Any] The repository level metadata. """ + _raise_if_running_loop("get_metadata", "get_metadata_async") return self._repository.get_metadata() @property @@ -631,6 +647,7 @@ def set_metadata(self, metadata: dict[str, Any]) -> None: metadata : dict[str, Any] The value to use as repository metadata. 
""" + _raise_if_running_loop("set_metadata", "set_metadata_async") self._repository.set_metadata(metadata) async def set_metadata_async(self, metadata: dict[str, Any]) -> None: @@ -657,6 +674,7 @@ def update_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: metadata : dict[str, Any] The dict to merge into the repository metadata. """ + _raise_if_running_loop("update_metadata", "update_metadata_async") return self._repository.update_metadata(metadata) async def update_metadata_async(self, metadata: dict[str, Any]) -> dict[str, Any]: @@ -700,6 +718,7 @@ def ancestry( ----- Only one of the arguments can be specified. """ + _raise_if_running_loop("ancestry", "async_ancestry", await_async=False) # the returned object is both an Async and Sync iterator res = cast( @@ -746,6 +765,7 @@ def ops_log(self) -> Iterator[UpdateType]: """ Get a summary of changes to the repository """ + _raise_if_running_loop("ops_log", "ops_log_async", await_async=False) # the returned object is both an Async and Sync iterator res = cast( @@ -777,6 +797,7 @@ def create_branch(self, branch: str, snapshot_id: str) -> None: ------- None """ + _raise_if_running_loop("create_branch", "create_branch_async") self._repository.create_branch(branch, snapshot_id) async def create_branch_async(self, branch: str, snapshot_id: str) -> None: @@ -833,6 +854,7 @@ def lookup_branch(self, branch: str) -> str: str The snapshot ID of the tip of the branch. 
""" + _raise_if_running_loop("lookup_branch", "lookup_branch_async") return self._repository.lookup_branch(branch) async def lookup_branch_async(self, branch: str) -> str: @@ -864,6 +886,7 @@ def lookup_snapshot(self, snapshot_id: str) -> SnapshotInfo: ------- SnapshotInfo """ + _raise_if_running_loop("lookup_snapshot", "lookup_snapshot_async") return self._repository.lookup_snapshot(snapshot_id) async def lookup_snapshot_async(self, snapshot_id: str) -> SnapshotInfo: @@ -894,6 +917,7 @@ def list_manifest_files(self, snapshot_id: str) -> list[ManifestFileInfo]: ------- list[ManifestFileInfo] """ + _raise_if_running_loop("list_manifest_files", "list_manifest_files_async") return self._repository.list_manifest_files(snapshot_id) async def list_manifest_files_async(self, snapshot_id: str) -> list[ManifestFileInfo]: @@ -934,6 +958,7 @@ def reset_branch( ------- None """ + _raise_if_running_loop("reset_branch", "reset_branch_async") self._repository.reset_branch(branch, snapshot_id, from_snapshot_id) async def reset_branch_async( @@ -974,6 +999,7 @@ def delete_branch(self, branch: str) -> None: ------- None """ + _raise_if_running_loop("delete_branch", "delete_branch_async") self._repository.delete_branch(branch) async def delete_branch_async(self, branch: str) -> None: @@ -1004,6 +1030,7 @@ def delete_tag(self, tag: str) -> None: ------- None """ + _raise_if_running_loop("delete_tag", "delete_tag_async") self._repository.delete_tag(tag) async def delete_tag_async(self, tag: str) -> None: @@ -1036,6 +1063,7 @@ def create_tag(self, tag: str, snapshot_id: str) -> None: ------- None """ + _raise_if_running_loop("create_tag", "create_tag_async") self._repository.create_tag(tag, snapshot_id) async def create_tag_async(self, tag: str, snapshot_id: str) -> None: @@ -1064,6 +1092,7 @@ def list_tags(self) -> set[str]: set[str] A set of tag names. 
""" + _raise_if_running_loop("list_tags", "list_tags_async") return self._repository.list_tags() async def list_tags_async(self) -> set[str]: @@ -1091,6 +1120,7 @@ def lookup_tag(self, tag: str) -> str: str The snapshot ID of the tag. """ + _raise_if_running_loop("lookup_tag", "lookup_tag_async") return self._repository.lookup_tag(tag) async def lookup_tag_async(self, tag: str) -> str: @@ -1132,6 +1162,7 @@ def diff( Diff The operations executed between the two versions """ + _raise_if_running_loop("diff", "diff_async") return self._repository.diff( from_branch=from_branch, from_tag=from_tag, @@ -1209,6 +1240,7 @@ def readonly_session( ----- Only one of the arguments can be specified. """ + _raise_if_running_loop("readonly_session", "readonly_session_async") return Session( self._repository.readonly_session( branch=branch, tag=tag, snapshot_id=snapshot_id, as_of=as_of @@ -1276,6 +1308,7 @@ def writable_session(self, branch: str) -> Session: Session The writable session on the branch. """ + _raise_if_running_loop("writable_session", "writable_session_async") return Session(self._repository.writable_session(branch)) async def writable_session_async(self, branch: str) -> Session: @@ -1321,6 +1354,7 @@ def rearrange_session(self, branch: str) -> Session: Session The writable session on the branch. """ + _raise_if_running_loop("rearrange_session", "rearrange_session_async") return Session(self._repository.rearrange_session(branch)) async def rearrange_session_async(self, branch: str) -> Session: @@ -1432,6 +1466,7 @@ def expire_snapshots( ------- set of expires snapshot IDs """ + _raise_if_running_loop("expire_snapshots", "expire_snapshots_async") return self._repository.expire_snapshots( older_than, delete_expired_branches=delete_expired_branches, @@ -1513,6 +1548,7 @@ def rewrite_manifests( The snapshot ID of the new commit. 
""" + _raise_if_running_loop("rewrite_manifests", "rewrite_manifests_async") return self._repository.rewrite_manifests( message, branch=branch, metadata=metadata ) @@ -1585,6 +1621,7 @@ def garbage_collect( GCSummary Summary of objects deleted. """ + _raise_if_running_loop("garbage_collect", "garbage_collect_async") return self._repository.garbage_collect( delete_object_older_than, @@ -1666,6 +1703,7 @@ def chunk_storage_stats( max_concurrent_manifest_fetches : int Don't run more than this many concurrent manifest fetches. """ + _raise_if_running_loop("chunk_storage_stats", "chunk_storage_stats_async") return self._repository.chunk_storage_stats( max_snapshots_in_memory=max_snapshots_in_memory, max_compressed_manifest_mem_bytes=max_compressed_manifest_mem_bytes, @@ -1731,6 +1769,7 @@ def total_chunks_storage( max_concurrent_manifest_fetches : int Don't run more than this many concurrent manifest fetches. """ + _raise_if_running_loop("total_chunks_storage", "total_chunks_storage_async") warnings.warn( "The ``total_chunks_storage`` method has been deprecated in favour of the ``chunk_storage_stats`` method. 
" @@ -1792,6 +1831,7 @@ async def total_chunks_storage_async( return stats.native_bytes def inspect_snapshot(self, snapshot_id: str, *, pretty: bool = True) -> str: + _raise_if_running_loop("inspect_snapshot", "inspect_snapshot_async") return self._repository.inspect_snapshot(snapshot_id, pretty=pretty) async def inspect_snapshot_async( diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index dbcc4bae7..530cc2541 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -382,13 +382,48 @@ async def test_sync_list_branches_in_async_context_errors( access_key_id="minio123", secret_access_key="minio123", ) - Repository.create(storage=create_storage, spec_version=any_spec_version) + await Repository.create_async( + storage=create_storage, spec_version=any_spec_version + ) repo = await Repository.open_async(create_storage) with pytest.raises(RuntimeError, match="list_branches_async"): repo.list_branches() +@pytest.mark.asyncio +async def test_sync_repository_apis_in_async_context_error_consistently( + any_spec_version: int | None, +) -> None: + prefix = "test_sync_repository_apis_in_async_context_error_consistently-" + str( + int(time.time() * 1000) + ) + + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + await Repository.create_async( + storage=create_storage, spec_version=any_spec_version + ) + repo = await Repository.open_async(create_storage) + + with pytest.raises(RuntimeError, match="exists_async"): + Repository.exists(create_storage) + with pytest.raises(RuntimeError, match="lookup_branch_async"): + repo.lookup_branch("main") + with pytest.raises(RuntimeError, match=r"Repository\.async_ancestry"): + repo.ancestry(branch="main") + with pytest.raises(RuntimeError, 
match=r"Repository\.ops_log_async"): + repo.ops_log() + + @pytest.mark.asyncio async def test_async_refreshable_credentials_with_async_repository_api( tmp_path: Path, any_spec_version: int | None @@ -405,7 +440,9 @@ async def test_async_refreshable_credentials_with_async_repository_api( access_key_id="minio123", secret_access_key="minio123", ) - Repository.create(storage=create_storage, spec_version=any_spec_version) + await Repository.create_async( + storage=create_storage, spec_version=any_spec_version + ) calls_path = tmp_path / "async_async_calls.txt" context_value = "expected-refresh-context" From 6f842c49025453fccc4f1c356610368586542b05 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 12 Feb 2026 17:07:54 -0500 Subject: [PATCH 05/14] Simplify --- .../python/icechunk/credentials.py | 2 +- icechunk-python/python/icechunk/repository.py | 53 -------------- icechunk-python/src/config.rs | 7 ++ icechunk-python/tests/test_credentials.py | 72 ++----------------- 4 files changed, 13 insertions(+), 121 deletions(-) diff --git a/icechunk-python/python/icechunk/credentials.py b/icechunk-python/python/icechunk/credentials.py index 7cd31fe9f..acb331597 100644 --- a/icechunk-python/python/icechunk/credentials.py +++ b/icechunk-python/python/icechunk/credentials.py @@ -71,8 +71,8 @@ async def _as_coroutine() -> CredentialType: try: return asyncio.run(coroutine) except RuntimeError as err: + coroutine.close() if "asyncio.run() cannot be called from a running event loop" in str(err): - coroutine.close() raise ValueError( "scatter_initial_credentials=True cannot eagerly evaluate async " "credential callbacks while an event loop is running. 
" diff --git a/icechunk-python/python/icechunk/repository.py b/icechunk-python/python/icechunk/repository.py index 6d59a4148..6ef06c074 100644 --- a/icechunk-python/python/icechunk/repository.py +++ b/icechunk-python/python/icechunk/repository.py @@ -1,4 +1,3 @@ -import asyncio import datetime import warnings from collections.abc import AsyncIterator, Iterator @@ -22,24 +21,6 @@ from icechunk.store import IcechunkStore -def _raise_if_running_loop( - sync_method: str, async_method: str, *, await_async: bool = True -) -> None: - try: - asyncio.get_running_loop() - except RuntimeError: - return - guidance = ( - f"Use `await Repository.{async_method}` instead." - if await_async - else f"Use `Repository.{async_method}` instead." - ) - raise RuntimeError( - f"`Repository.{sync_method}` cannot be called while an asyncio event loop is " - f"running. {guidance}" - ) - - class Repository: """An Icechunk repository.""" @@ -86,7 +67,6 @@ def create( Self An instance of the Repository class. """ - _raise_if_running_loop("create", "create_async") return cls( PyRepository.create( storage, @@ -179,7 +159,6 @@ def open( Self An instance of the Repository class. """ - _raise_if_running_loop("open", "open_async") return cls( PyRepository.open( storage, @@ -275,7 +254,6 @@ def open_or_create( Self An instance of the Repository class. """ - _raise_if_running_loop("open_or_create", "open_or_create_async") return cls( PyRepository.open_or_create( storage, @@ -351,7 +329,6 @@ def exists(storage: Storage) -> bool: bool True if the repository exists, False otherwise. """ - _raise_if_running_loop("exists", "exists_async") return PyRepository.exists(storage) @staticmethod @@ -390,7 +367,6 @@ def fetch_spec_version(storage: Storage) -> int | None: The spec version of the repository if it exists, None if no repository exists at the given location. 
""" - _raise_if_running_loop("fetch_spec_version", "fetch_spec_version_async") return PyRepository.fetch_spec_version(storage) @staticmethod @@ -439,7 +415,6 @@ def fetch_config(storage: Storage) -> RepositoryConfig | None: RepositoryConfig | None The repository configuration if it exists, None otherwise. """ - _raise_if_running_loop("fetch_config", "fetch_config_async") return PyRepository.fetch_config(storage) @staticmethod @@ -467,7 +442,6 @@ def save_config(self) -> None: ------- None """ - _raise_if_running_loop("save_config", "save_config_async") return self._repository.save_config() async def save_config_async(self) -> None: @@ -536,7 +510,6 @@ def reopen( Self A new Repository instance with the updated configuration. """ - _raise_if_running_loop("reopen", "reopen_async") return self.__class__( self._repository.reopen( config=config, @@ -610,7 +583,6 @@ def get_metadata(self) -> dict[str, Any]: dict[str, Any] The repository level metadata. """ - _raise_if_running_loop("get_metadata", "get_metadata_async") return self._repository.get_metadata() @property @@ -647,7 +619,6 @@ def set_metadata(self, metadata: dict[str, Any]) -> None: metadata : dict[str, Any] The value to use as repository metadata. """ - _raise_if_running_loop("set_metadata", "set_metadata_async") self._repository.set_metadata(metadata) async def set_metadata_async(self, metadata: dict[str, Any]) -> None: @@ -674,7 +645,6 @@ def update_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: metadata : dict[str, Any] The dict to merge into the repository metadata. """ - _raise_if_running_loop("update_metadata", "update_metadata_async") return self._repository.update_metadata(metadata) async def update_metadata_async(self, metadata: dict[str, Any]) -> dict[str, Any]: @@ -718,7 +688,6 @@ def ancestry( ----- Only one of the arguments can be specified. 
""" - _raise_if_running_loop("ancestry", "async_ancestry", await_async=False) # the returned object is both an Async and Sync iterator res = cast( @@ -765,7 +734,6 @@ def ops_log(self) -> Iterator[UpdateType]: """ Get a summary of changes to the repository """ - _raise_if_running_loop("ops_log", "ops_log_async", await_async=False) # the returned object is both an Async and Sync iterator res = cast( @@ -797,7 +765,6 @@ def create_branch(self, branch: str, snapshot_id: str) -> None: ------- None """ - _raise_if_running_loop("create_branch", "create_branch_async") self._repository.create_branch(branch, snapshot_id) async def create_branch_async(self, branch: str, snapshot_id: str) -> None: @@ -826,7 +793,6 @@ def list_branches(self) -> set[str]: set[str] A set of branch names. """ - _raise_if_running_loop("list_branches", "list_branches_async") return self._repository.list_branches() async def list_branches_async(self) -> set[str]: @@ -854,7 +820,6 @@ def lookup_branch(self, branch: str) -> str: str The snapshot ID of the tip of the branch. 
""" - _raise_if_running_loop("lookup_branch", "lookup_branch_async") return self._repository.lookup_branch(branch) async def lookup_branch_async(self, branch: str) -> str: @@ -886,7 +851,6 @@ def lookup_snapshot(self, snapshot_id: str) -> SnapshotInfo: ------- SnapshotInfo """ - _raise_if_running_loop("lookup_snapshot", "lookup_snapshot_async") return self._repository.lookup_snapshot(snapshot_id) async def lookup_snapshot_async(self, snapshot_id: str) -> SnapshotInfo: @@ -917,7 +881,6 @@ def list_manifest_files(self, snapshot_id: str) -> list[ManifestFileInfo]: ------- list[ManifestFileInfo] """ - _raise_if_running_loop("list_manifest_files", "list_manifest_files_async") return self._repository.list_manifest_files(snapshot_id) async def list_manifest_files_async(self, snapshot_id: str) -> list[ManifestFileInfo]: @@ -958,7 +921,6 @@ def reset_branch( ------- None """ - _raise_if_running_loop("reset_branch", "reset_branch_async") self._repository.reset_branch(branch, snapshot_id, from_snapshot_id) async def reset_branch_async( @@ -999,7 +961,6 @@ def delete_branch(self, branch: str) -> None: ------- None """ - _raise_if_running_loop("delete_branch", "delete_branch_async") self._repository.delete_branch(branch) async def delete_branch_async(self, branch: str) -> None: @@ -1030,7 +991,6 @@ def delete_tag(self, tag: str) -> None: ------- None """ - _raise_if_running_loop("delete_tag", "delete_tag_async") self._repository.delete_tag(tag) async def delete_tag_async(self, tag: str) -> None: @@ -1063,7 +1023,6 @@ def create_tag(self, tag: str, snapshot_id: str) -> None: ------- None """ - _raise_if_running_loop("create_tag", "create_tag_async") self._repository.create_tag(tag, snapshot_id) async def create_tag_async(self, tag: str, snapshot_id: str) -> None: @@ -1092,7 +1051,6 @@ def list_tags(self) -> set[str]: set[str] A set of tag names. 
""" - _raise_if_running_loop("list_tags", "list_tags_async") return self._repository.list_tags() async def list_tags_async(self) -> set[str]: @@ -1120,7 +1078,6 @@ def lookup_tag(self, tag: str) -> str: str The snapshot ID of the tag. """ - _raise_if_running_loop("lookup_tag", "lookup_tag_async") return self._repository.lookup_tag(tag) async def lookup_tag_async(self, tag: str) -> str: @@ -1162,7 +1119,6 @@ def diff( Diff The operations executed between the two versions """ - _raise_if_running_loop("diff", "diff_async") return self._repository.diff( from_branch=from_branch, from_tag=from_tag, @@ -1240,7 +1196,6 @@ def readonly_session( ----- Only one of the arguments can be specified. """ - _raise_if_running_loop("readonly_session", "readonly_session_async") return Session( self._repository.readonly_session( branch=branch, tag=tag, snapshot_id=snapshot_id, as_of=as_of @@ -1308,7 +1263,6 @@ def writable_session(self, branch: str) -> Session: Session The writable session on the branch. """ - _raise_if_running_loop("writable_session", "writable_session_async") return Session(self._repository.writable_session(branch)) async def writable_session_async(self, branch: str) -> Session: @@ -1354,7 +1308,6 @@ def rearrange_session(self, branch: str) -> Session: Session The writable session on the branch. """ - _raise_if_running_loop("rearrange_session", "rearrange_session_async") return Session(self._repository.rearrange_session(branch)) async def rearrange_session_async(self, branch: str) -> Session: @@ -1466,7 +1419,6 @@ def expire_snapshots( ------- set of expires snapshot IDs """ - _raise_if_running_loop("expire_snapshots", "expire_snapshots_async") return self._repository.expire_snapshots( older_than, delete_expired_branches=delete_expired_branches, @@ -1548,7 +1500,6 @@ def rewrite_manifests( The snapshot ID of the new commit. 
""" - _raise_if_running_loop("rewrite_manifests", "rewrite_manifests_async") return self._repository.rewrite_manifests( message, branch=branch, metadata=metadata ) @@ -1621,7 +1572,6 @@ def garbage_collect( GCSummary Summary of objects deleted. """ - _raise_if_running_loop("garbage_collect", "garbage_collect_async") return self._repository.garbage_collect( delete_object_older_than, @@ -1703,7 +1653,6 @@ def chunk_storage_stats( max_concurrent_manifest_fetches : int Don't run more than this many concurrent manifest fetches. """ - _raise_if_running_loop("chunk_storage_stats", "chunk_storage_stats_async") return self._repository.chunk_storage_stats( max_snapshots_in_memory=max_snapshots_in_memory, max_compressed_manifest_mem_bytes=max_compressed_manifest_mem_bytes, @@ -1769,7 +1718,6 @@ def total_chunks_storage( max_concurrent_manifest_fetches : int Don't run more than this many concurrent manifest fetches. """ - _raise_if_running_loop("total_chunks_storage", "total_chunks_storage_async") warnings.warn( "The ``total_chunks_storage`` method has been deprecated in favour of the ``chunk_storage_stats`` method. 
" @@ -1831,7 +1779,6 @@ async def total_chunks_storage_async( return stats.native_bytes def inspect_snapshot(self, snapshot_id: str, *, pretty: bool = True) -> str: - _raise_if_running_loop("inspect_snapshot", "inspect_snapshot_async") return self._repository.inspect_snapshot(snapshot_id, pretty=pretty) async def inspect_snapshot_async( diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 2be39de9e..568a751df 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -141,6 +141,9 @@ pub(crate) fn datetime_repr(d: &DateTime) -> String { struct PythonCredentialsFetcher { pub pickled_function: Vec, pub initial: Option, + // Intentionally skipped during serialization — on deserialization, + // `await_python_callback` calls `current_task_locals()` to get + // fresh locals from whatever event loop is active at that point. #[serde(skip, default)] pub task_locals: Option, } @@ -224,6 +227,10 @@ where return Python::attach(|py| value.bind(py).extract().map_err(Into::into)); } + // No task locals available — no Python event loop is currently running. + // Fall back to asyncio.run() which creates a temporary event loop to + // drive the awaitable. This is the correct path when an async credential + // callback is used from a sync context (e.g. Repository.open). Python::attach(|py| { let asyncio = PyModule::import(py, "asyncio")?; let value = asyncio.getattr("run")?.call1((awaitable.bind(py),))?; diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index 530cc2541..582cfab8d 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -281,7 +281,7 @@ def test_async_refreshable_credentials_with_sync_repository_api( repo = Repository.open(callback_storage) assert "main" in repo.list_branches() - assert calls_path.read_text() != "" + assert "." 
in calls_path.read_text() def test_async_refreshable_credentials_constructed_sync_used_async( @@ -319,7 +319,7 @@ async def use_async_repository_api() -> None: assert "main" in await repo.list_branches_async() asyncio.run(use_async_repository_api()) - assert calls_path.read_text() != "" + assert "." in calls_path.read_text() def test_async_refreshable_credentials_repo_reused_across_event_loops( @@ -361,67 +361,7 @@ async def use_repo_on_different_loop() -> None: assert "main" in await repo.list_branches_async() asyncio.run(use_repo_on_different_loop()) - assert calls_path.read_text() != "" - - -@pytest.mark.asyncio -async def test_sync_list_branches_in_async_context_errors( - any_spec_version: int | None, -) -> None: - prefix = "test_sync_list_branches_in_async_context_errors-" + str( - int(time.time() * 1000) - ) - - create_storage = s3_storage( - region="us-east-1", - endpoint_url="http://localhost:9000", - allow_http=True, - force_path_style=True, - bucket="testbucket", - prefix=prefix, - access_key_id="minio123", - secret_access_key="minio123", - ) - await Repository.create_async( - storage=create_storage, spec_version=any_spec_version - ) - - repo = await Repository.open_async(create_storage) - with pytest.raises(RuntimeError, match="list_branches_async"): - repo.list_branches() - - -@pytest.mark.asyncio -async def test_sync_repository_apis_in_async_context_error_consistently( - any_spec_version: int | None, -) -> None: - prefix = "test_sync_repository_apis_in_async_context_error_consistently-" + str( - int(time.time() * 1000) - ) - - create_storage = s3_storage( - region="us-east-1", - endpoint_url="http://localhost:9000", - allow_http=True, - force_path_style=True, - bucket="testbucket", - prefix=prefix, - access_key_id="minio123", - secret_access_key="minio123", - ) - await Repository.create_async( - storage=create_storage, spec_version=any_spec_version - ) - repo = await Repository.open_async(create_storage) - - with pytest.raises(RuntimeError, 
match="exists_async"): - Repository.exists(create_storage) - with pytest.raises(RuntimeError, match="lookup_branch_async"): - repo.lookup_branch("main") - with pytest.raises(RuntimeError, match=r"Repository\.async_ancestry"): - repo.ancestry(branch="main") - with pytest.raises(RuntimeError, match=r"Repository\.ops_log_async"): - repo.ops_log() + assert "." in calls_path.read_text() @pytest.mark.asyncio @@ -440,9 +380,7 @@ async def test_async_refreshable_credentials_with_async_repository_api( access_key_id="minio123", secret_access_key="minio123", ) - await Repository.create_async( - storage=create_storage, spec_version=any_spec_version - ) + await Repository.create_async(storage=create_storage, spec_version=any_spec_version) calls_path = tmp_path / "async_async_calls.txt" context_value = "expected-refresh-context" @@ -465,4 +403,4 @@ async def test_async_refreshable_credentials_with_async_repository_api( finally: ASYNC_CREDENTIALS_CONTEXT.reset(reset_token) - assert calls_path.read_text() != "" + assert "." in calls_path.read_text() From a1e4779a4fd9bece7d1b77ebca9dd300a11420d7 Mon Sep 17 00:00:00 2001 From: Samantha Hughes Date: Fri, 13 Feb 2026 14:47:19 -0800 Subject: [PATCH 06/14] add some tests related to asyncio loop consistency + deadlocks --- icechunk-python/tests/test_credentials.py | 198 ++++++++++++++++++++++ 1 file changed, 198 insertions(+) diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index 582cfab8d..f4417bbbc 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -1,6 +1,7 @@ import asyncio import contextvars import pickle +import threading import time from datetime import UTC, datetime from pathlib import Path @@ -404,3 +405,200 @@ async def test_async_refreshable_credentials_with_async_repository_api( ASYNC_CREDENTIALS_CONTEXT.reset(reset_token) assert "." 
in calls_path.read_text() + + +class GoodException(Exception): + pass + + +class BadException(Exception): + pass + + +# survives pickle because pickle imports the module, doesn't recreate globals +_cred_refresh_loop_ids: list[int] = [] + + +@pytest.fixture(autouse=False) +def reset_cred_tracker(): + _cred_refresh_loop_ids.clear() + yield _cred_refresh_loop_ids + _cred_refresh_loop_ids.clear() + + +@pytest.fixture() +def minio_repo_prefix(any_spec_version: int | None) -> str: + prefix = "test_async_creds-" + str(int(time.time() * 1000)) + create_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + access_key_id="minio123", + secret_access_key="minio123", + ) + Repository.create(storage=create_storage, spec_version=any_spec_version) + return prefix + + +class AsyncCredentialTracker: + def __init__( + self, + expected_loop_id: int | None = None, + return_creds: bool = False, + ): + self.expected_loop_id = expected_loop_id + self.return_creds = return_creds + + async def __call__(self) -> S3StaticCredentials: + await asyncio.sleep(0) + loop_id = id(asyncio.get_running_loop()) + _cred_refresh_loop_ids.append(loop_id) + if self.expected_loop_id is not None and loop_id != self.expected_loop_id: + raise BadException("Wrong event loop") + if self.return_creds: + return S3StaticCredentials( + access_key_id="minio123", + secret_access_key="minio123", + expires_after=datetime.now(UTC), + ) + raise GoodException("YOLO") + + +def test_async_cred_refresh_uses_same_loop( + minio_repo_prefix: str, reset_cred_tracker: list[int] +) -> None: + prefix = minio_repo_prefix + + async def run() -> None: + loop = asyncio.get_running_loop() + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncCredentialTracker( + 
expected_loop_id=id(loop), return_creds=True + ), + ) + repo = await Repository.open_async(callback_storage) + assert "main" in await repo.list_branches_async() + # FIXME: this deadlocks + # _ = [snap async for snap in repo.async_ancestry(branch="main")] + snap = await repo.lookup_branch_async("main") + await repo.create_branch_async("test-write", snap) + # all that should have triggered at least one refresh + assert len(_cred_refresh_loop_ids) > 1 + assert len(set(_cred_refresh_loop_ids)) == 1 + + asyncio.run(run()) + + +def test_async_cred_refresh_uses_originator_loop_from_thread( + minio_repo_prefix: str, reset_cred_tracker: list[int] +) -> None: + prefix = minio_repo_prefix + + async def run() -> None: + loop = asyncio.get_running_loop() + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncCredentialTracker( + expected_loop_id=id(loop), return_creds=True + ), + ) + + def sync_in_thread() -> None: + repo = Repository.open(callback_storage) + calls_after_open = len(_cred_refresh_loop_ids) + assert calls_after_open >= 1 + + repo.list_branches() + calls_after_list_branches = len(_cred_refresh_loop_ids) + assert calls_after_list_branches > calls_after_open + + list(repo.ancestry(branch="main")) + calls_after_ancestry = len(_cred_refresh_loop_ids) + assert calls_after_ancestry > calls_after_list_branches + + await loop.run_in_executor(None, sync_in_thread) + assert len(set(_cred_refresh_loop_ids)) == 1 + + asyncio.run(run()) + + +def test_async_cred_refresh_graceful_deadlock(): + # Deadlock scenario: sync list_branches called directly on the event loop + # thread where the credential callback would also need to run. Use a thread + # with a timeout so the test doesn't hang forever if it actually deadlocks. 
+ caught_error: BaseException | None = None + + async def sync_call_on_same_loop_as_callback() -> None: + loop = asyncio.get_running_loop() + current_loop_id = id(loop) + storage = s3_storage( + region="us-east-1", + bucket="testbucket", + prefix="placeholder", + get_credentials=AsyncCredentialTracker(expected_loop_id=current_loop_id), + ) + # A user can very easily make this mistake! + # we should be gracefully falling over not deadlocking + with pytest.raises(IcechunkError, match="deadlock"): + Repository.open(storage=storage) + + def run_deadlock_check() -> None: + nonlocal caught_error + try: + asyncio.run(sync_call_on_same_loop_as_callback()) + except BaseException as e: + caught_error = e + + t = threading.Thread(target=run_deadlock_check, daemon=True) + t.start() + t.join(timeout=1) + assert not t.is_alive(), "Deadlocked: sync call blocked the event loop thread" + assert isinstance(caught_error, IcechunkError) + assert caught_error == "YOLO", (str(caught_error), type(caught_error)) + + +def test_async_callback_no_loop_has_consistent_loop( + minio_repo_prefix: str, reset_cred_tracker: list[int] +) -> None: + prefix = minio_repo_prefix + + callback_storage = s3_storage( + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, + force_path_style=True, + bucket="testbucket", + prefix=prefix, + get_credentials=AsyncCredentialTracker(return_creds=True), + ) + + repo = Repository.open(callback_storage) + repo.list_branches() + list(repo.ancestry(branch="main")) + assert len(_cred_refresh_loop_ids) > 3 + + def on_another_thread(r: Repository) -> None: + r.list_branches() + + t = threading.Thread(target=on_another_thread, args=(repo,)) + t.start() + t.join() + assert len(_cred_refresh_loop_ids) > 4 + assert len(set(_cred_refresh_loop_ids)) == 1 + assert len(set(_cred_refresh_loop_ids)) == 1, ( + "All credential refresh calls should be on the same loop even if there isn't an event loop at the start" + ) From 
93ce1931a98d5e1fcb5b2e50f382ec8c17eab9a1 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Mon, 16 Feb 2026 13:17:16 -0500 Subject: [PATCH 07/14] async cred callback: fallback event loop + deadlock detection - When no Python event loop is running, create a persistent fallback loop on a background thread instead of calling asyncio.run() each time (which creates a new temporary loop per invocation). - Detect potential deadlocks when the credential callback targets the same event loop that is currently blocked, and raise an error instead of hanging. Co-Authored-By: Claude Opus 4.6 --- icechunk-python/src/config.rs | 110 +++++++++++++++++++++++++++------- 1 file changed, 89 insertions(+), 21 deletions(-) diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 568a751df..82837fcf2 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use chrono::{DateTime, Datelike, TimeDelta, Timelike, Utc}; use icechunk::storage::RetriesSettings; use itertools::Itertools; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; use std::{ @@ -11,7 +11,7 @@ use std::{ hash::DefaultHasher, num::{NonZeroU16, NonZeroU64}, path::PathBuf, - sync::Arc, + sync::{Arc, OnceLock}, }; use icechunk::{ @@ -173,6 +173,74 @@ fn current_task_locals() -> Option { Python::attach(|py| pyo3_async_runtimes::tokio::get_current_locals(py).ok()) } +static FALLBACK_TASK_LOCALS: OnceLock> = + OnceLock::new(); + +fn object_id(py: Python<'_>, obj: &Bound<'_, PyAny>) -> Result { + let builtins = PyModule::import(py, "builtins")?; + builtins.getattr("id")?.call1((obj,))?.extract() +} + +fn would_deadlock_current_loop( + task_locals: &pyo3_async_runtimes::TaskLocals, +) -> Result { + Python::attach(|py| { + let running_loop = match pyo3_async_runtimes::get_running_loop(py) { + Ok(loop_ref) => loop_ref, + Err(err) if 
err.is_instance_of::(py) => return Ok(false), + Err(err) => return Err(err), + }; + + let target_loop = task_locals.event_loop(py); + Ok(object_id(py, &running_loop)? == object_id(py, &target_loop)?) + }) +} + +fn create_fallback_task_locals() -> Result { + let (tx, rx) = std::sync::mpsc::sync_channel::< + Result, + >(1); + + std::thread::Builder::new() + .name("icechunk-python-asyncio-fallback".to_string()) + .spawn(move || { + let mut tx = Some(tx); + let run_result = Python::attach(|py| -> PyResult<()> { + let asyncio = PyModule::import(py, "asyncio")?; + let event_loop = asyncio.getattr("new_event_loop")?.call0()?; + asyncio.getattr("set_event_loop")?.call1((event_loop.clone(),))?; + + let locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone()); + let _ = tx.take().expect("sender must exist").send(Ok(locals)); + + event_loop.call_method0("run_forever")?; + Ok(()) + }); + + if let Err(err) = run_result { + if let Some(sender) = tx { + let _ = sender.send(Err(err.to_string())); + } + } + }) + .map_err(|err| format!("failed to start fallback asyncio loop thread: {err}"))?; + + match rx.recv() { + Ok(Ok(locals)) => Ok(locals), + Ok(Err(err)) => Err(format!("failed to initialize fallback asyncio loop: {err}")), + Err(err) => Err(format!( + "fallback asyncio loop thread terminated before initialization: {err}" + )), + } +} + +fn fallback_task_locals() -> Result { + match FALLBACK_TASK_LOCALS.get_or_init(create_fallback_task_locals) { + Ok(locals) => Ok(locals.clone()), + Err(err) => Err(PyRuntimeError::new_err(err.clone())), + } +} + enum PythonCallbackResult { Value(PyCred), Awaitable(Py), @@ -216,26 +284,26 @@ where PyCred: for<'a, 'py> FromPyObject<'a, 'py>, for<'a, 'py> >::Error: Into, { - if let Some(task_locals) = current_task_locals().or(task_locals) { - let fut = Python::attach(|py| { - pyo3_async_runtimes::into_future_with_locals( - &task_locals, - awaitable.bind(py).clone(), - ) - })?; - let value = fut.await?; - return Python::attach(|py| 
value.bind(py).extract().map_err(Into::into)); - } + let task_locals = match current_task_locals().or(task_locals) { + Some(task_locals) => { + if would_deadlock_current_loop(&task_locals)? { + return Err(PyValueError::new_err( + "deadlock: async credential callback targets the currently blocked event loop thread", + )); + } + task_locals + } + None => fallback_task_locals()?, + }; - // No task locals available — no Python event loop is currently running. - // Fall back to asyncio.run() which creates a temporary event loop to - // drive the awaitable. This is the correct path when an async credential - // callback is used from a sync context (e.g. Repository.open). - Python::attach(|py| { - let asyncio = PyModule::import(py, "asyncio")?; - let value = asyncio.getattr("run")?.call1((awaitable.bind(py),))?; - value.extract().map_err(Into::into) - }) + let fut = Python::attach(|py| { + pyo3_async_runtimes::into_future_with_locals( + &task_locals, + awaitable.bind(py).clone(), + ) + })?; + let value = fut.await?; + Python::attach(|py| value.bind(py).extract().map_err(Into::into)) } #[async_trait] From b33d134f14bd3edd628c8420464b92118257397e Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Mon, 16 Feb 2026 13:22:39 -0500 Subject: [PATCH 08/14] Lint --- icechunk-python/src/config.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 82837fcf2..9ba7d3e0c 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -217,10 +217,10 @@ fn create_fallback_task_locals() -> Result Date: Mon, 16 Feb 2026 14:58:59 -0500 Subject: [PATCH 09/14] Detecting event loops --- icechunk-python/src/lib.rs | 1 + icechunk-python/src/repository.rs | 35 +++++++++++++++++++++++ icechunk-python/src/session.rs | 14 +++++++++ icechunk-python/src/sync_api.rs | 14 +++++++++ icechunk-python/tests/test_credentials.py | 8 +++--- 5 files changed, 68 insertions(+), 4 deletions(-) create
mode 100644 icechunk-python/src/sync_api.rs diff --git a/icechunk-python/src/lib.rs b/icechunk-python/src/lib.rs index 4426fcc26..c68b8205c 100644 --- a/icechunk-python/src/lib.rs +++ b/icechunk-python/src/lib.rs @@ -7,6 +7,7 @@ mod session; mod stats; mod store; mod streams; +mod sync_api; use std::env; diff --git a/icechunk-python/src/repository.rs b/icechunk-python/src/repository.rs index 8b9fb20d6..28a36b9e5 100644 --- a/icechunk-python/src/repository.rs +++ b/icechunk-python/src/repository.rs @@ -47,6 +47,7 @@ use crate::{ session::PySession, stats::PyChunkStorageStats, streams::PyAsyncGenerator, + sync_api::ensure_not_running_event_loop, }; /// Wrapper needed to implement pyo3 conversion classes @@ -815,6 +816,7 @@ impl PyRepository { dry_run: bool, delete_unused_v1_files: bool, ) -> PyResult<()> { + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let mut repo = self.0.write().await; @@ -842,6 +844,7 @@ impl PyRepository { spec_version: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let repository = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -912,6 +915,7 @@ impl PyRepository { authorize_virtual_chunk_access: Option>>, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let repository = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -964,6 +968,7 @@ impl PyRepository { create_version: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let repository = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -1028,6 +1033,7 @@ impl PyRepository { #[staticmethod] fn exists(py: 
Python<'_>, storage: PyStorage) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let exists = Repository::exists(storage.0) @@ -1051,6 +1057,7 @@ impl PyRepository { #[staticmethod] fn fetch_spec_version(py: Python<'_>, storage: PyStorage) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let spec_version = Repository::fetch_spec_version(storage.0) @@ -1088,6 +1095,7 @@ impl PyRepository { Option>>, >, ) -> PyResult { + ensure_not_running_event_loop(py)?; py.detach(move || { let config = config .map(|c| c.try_into().map_err(PyValueError::new_err)) @@ -1140,6 +1148,7 @@ impl PyRepository { bytes: Vec, ) -> PyResult { // This is a compute intensive task, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let repository = Repository::from_bytes(bytes) .map_err(PyIcechunkStoreError::RepositoryError)?; @@ -1149,6 +1158,7 @@ impl PyRepository { fn as_bytes(&self, py: Python<'_>) -> PyResult> { // This is a compute intensive task, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let bytes = self .0 @@ -1165,6 +1175,7 @@ impl PyRepository { storage: PyStorage, ) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let res = Repository::fetch_config(storage.0) @@ -1193,6 +1204,7 @@ impl PyRepository { fn save_config(&self, py: Python<'_>) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; 
py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let _etag = self @@ -1259,6 +1271,7 @@ impl PyRepository { } pub(crate) fn get_metadata(&self, py: Python<'_>) -> PyResult { + ensure_not_running_event_loop(py)?; py.detach(move || { let metadata = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -1294,6 +1307,7 @@ impl PyRepository { py: Python<'_>, metadata: PySnapshotProperties, ) -> PyResult<()> { + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1328,6 +1342,7 @@ impl PyRepository { py: Python<'_>, metadata: PySnapshotProperties, ) -> PyResult { + ensure_not_running_event_loop(py)?; py.detach(move || { let res = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1368,6 +1383,7 @@ impl PyRepository { snapshot_id: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let version = args_to_version_info(branch, tag, snapshot_id, None)?; let ancestry = pyo3_async_runtimes::tokio::get_runtime() @@ -1392,6 +1408,7 @@ impl PyRepository { pub(crate) fn async_ops_log(&self, py: Python<'_>) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let ops = pyo3_async_runtimes::tokio::get_runtime() .block_on(async move { @@ -1417,6 +1434,7 @@ impl PyRepository { snapshot_id: &str, ) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1462,6 +1480,7 @@ impl PyRepository { pub(crate) fn list_branches(&self, py: Python<'_>) -> PyResult> { // This function calls block_on, so we need 
to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let branches = self @@ -1500,6 +1519,7 @@ impl PyRepository { branch_name: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let tip = self @@ -1537,6 +1557,7 @@ impl PyRepository { snapshot_id: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1634,6 +1655,7 @@ impl PyRepository { from_snapshot_id: Option<&str>, ) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let to_snapshot_id = SnapshotId::try_from(to_snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1701,6 +1723,7 @@ impl PyRepository { pub(crate) fn delete_branch(&self, py: Python<'_>, branch: &str) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1733,6 +1756,7 @@ impl PyRepository { pub(crate) fn delete_tag(&self, py: Python<'_>, tag: &str) -> PyResult<()> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -1770,6 +1794,7 @@ impl PyRepository { snapshot_id: &str, ) -> PyResult<()> { // This function calls block_on, so we need to allow other 
thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( @@ -1815,6 +1840,7 @@ impl PyRepository { pub(crate) fn list_tags(&self, py: Python<'_>) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let tags = self @@ -1846,6 +1872,7 @@ impl PyRepository { pub(crate) fn lookup_tag(&self, py: Python<'_>, tag: &str) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let tag = self @@ -1893,6 +1920,7 @@ impl PyRepository { let to = args_to_version_info(to_branch, to_tag, to_snapshot_id, None)?; // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let diff = self @@ -1943,6 +1971,7 @@ impl PyRepository { as_of: Option>, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let version = args_to_version_info(branch, tag, snapshot_id, as_of)?; let session = @@ -1987,6 +2016,7 @@ impl PyRepository { branch: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let session = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2025,6 +2055,7 @@ impl PyRepository { branch: &str, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + 
ensure_not_running_event_loop(py)?; py.detach(move || { let session = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2066,6 +2097,7 @@ impl PyRepository { metadata: Option, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let metadata = metadata.map(|m| m.into()); let result = @@ -2128,6 +2160,7 @@ impl PyRepository { delete_expired_tags: bool, ) -> PyResult> { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let result = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2210,6 +2243,7 @@ impl PyRepository { max_concurrent_manifest_fetches: NonZeroU16, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let result = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { @@ -2275,6 +2309,7 @@ impl PyRepository { max_concurrent_manifest_fetches: NonZeroU16, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress + ensure_not_running_event_loop(py)?; py.detach(move || { let stats = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { diff --git a/icechunk-python/src/session.rs b/icechunk-python/src/session.rs index b5e8f0812..21fed6202 100644 --- a/icechunk-python/src/session.rs +++ b/icechunk-python/src/session.rs @@ -21,6 +21,7 @@ use crate::{ repository::{PyDiff, PySnapshotProperties}, store::PyStore, streams::PyAsyncGenerator, + sync_api::ensure_not_running_event_loop, }; #[pyclass] @@ -66,6 +67,7 @@ impl PySession { bytes: Vec, ) -> PyResult { // This is a compute intensive task, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let session = 
Session::from_bytes(bytes).map_err(PyIcechunkStoreError::SessionError)?; @@ -79,6 +81,7 @@ impl PySession { fn as_bytes(&self, py: Python<'_>) -> PyIcechunkStoreResult> { // This is a compute intensive task, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let bytes = self.0.blocking_read().as_bytes().map_err(PyIcechunkStoreError::from)?; @@ -118,6 +121,7 @@ impl PySession { pub fn status(&self, py: Python<'_>) -> PyResult { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); @@ -131,6 +135,7 @@ impl PySession { pub fn discard_changes(&self, py: Python<'_>) -> PyResult<()> { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { self.0 .blocking_write() @@ -152,6 +157,7 @@ impl PySession { let to = Path::new(to_path.as_str()) .map_err(|e| StoreError::from(StoreErrorKind::PathError(e))) .map_err(PyIcechunkStoreError::StoreError)?; + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async move { let mut session = self.0.write().await; @@ -245,6 +251,7 @@ impl PySession { #[getter] pub fn store(&self, py: Python<'_>) -> PyResult { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); let conc = session.config().get_partial_values_concurrency(); @@ -258,6 +265,7 @@ impl PySession { #[getter] pub fn config(&self, py: Python<'_>) -> PyResult { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); let config = session.config().clone().into(); @@ -267,6 +275,7 @@ impl PySession { pub fn all_virtual_chunk_locations(&self, py: Python<'_>) -> PyResult> { // This is blocking function, we need to release the Gil + 
ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); @@ -377,6 +386,7 @@ impl PySession { pub fn merge(&self, other: &PySession, py: Python<'_>) -> PyResult<()> { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { // TODO: bad clone let other = other.0.blocking_read().deref().clone(); @@ -421,6 +431,7 @@ impl PySession { ) -> PyResult { let metadata = metadata.map(|m| m.into()); // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async { let mut session = self.0.write().await; @@ -489,6 +500,7 @@ impl PySession { ) -> PyResult { let metadata = metadata.map(|m| m.into()); // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async { let mut session = self.0.write().await; @@ -532,6 +544,7 @@ impl PySession { ) -> PyResult { let metadata = metadata.map(|m| m.into()); // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { pyo3_async_runtimes::tokio::get_runtime().block_on(async { let mut session = self.0.write().await; @@ -566,6 +579,7 @@ impl PySession { pub fn rebase(&self, solver: PyConflictSolver, py: Python<'_>) -> PyResult<()> { // This is blocking function, we need to release the Gil + ensure_not_running_event_loop(py)?; py.detach(move || { let solver = solver.as_ref(); pyo3_async_runtimes::tokio::get_runtime().block_on(async { diff --git a/icechunk-python/src/sync_api.rs b/icechunk-python/src/sync_api.rs new file mode 100644 index 000000000..d049b5b96 --- /dev/null +++ b/icechunk-python/src/sync_api.rs @@ -0,0 +1,14 @@ +use pyo3::{ + PyResult, Python, + exceptions::{PyRuntimeError, PyValueError}, +}; + +pub(crate) fn ensure_not_running_event_loop(py: Python<'_>) -> 
PyResult<()> { + match pyo3_async_runtimes::get_running_loop(py) { + Ok(_) => Err(PyValueError::new_err( + "deadlock: synchronous API called from a running event loop thread; use the async API or run the sync call in a worker thread", + )), + Err(err) if err.is_instance_of::(py) => Ok(()), + Err(err) => Err(err), + } +} diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index f4417bbbc..dee483c6f 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -553,8 +553,8 @@ async def sync_call_on_same_loop_as_callback() -> None: ) # A user can very easily make this mistake! # we should be gracefully falling over not deadlocking - with pytest.raises(IcechunkError, match="deadlock"): - Repository.open(storage=storage) + # with pytest.raises(IcechunkError, match="deadlock"): + Repository.open(storage=storage) def run_deadlock_check() -> None: nonlocal caught_error @@ -567,8 +567,8 @@ def run_deadlock_check() -> None: t.start() t.join(timeout=1) assert not t.is_alive(), "Deadlocked: sync call blocked the event loop thread" - assert isinstance(caught_error, IcechunkError) - assert caught_error == "YOLO", (str(caught_error), type(caught_error)) + assert isinstance(caught_error, IcechunkError | ValueError) + assert "deadlock" in str(caught_error) def test_async_callback_no_loop_has_consistent_loop( From 9875113576f9b3b54625d0d13ea2241978c0ed0f Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Mon, 16 Feb 2026 15:15:10 -0500 Subject: [PATCH 10/14] Cleanup --- icechunk-python/src/config.rs | 22 +------------- icechunk-python/src/lib.rs | 2 +- icechunk-python/src/repository.rs | 27 +++++++++++------ icechunk-python/src/session.rs | 18 +++++++----- icechunk-python/src/sync.rs | 35 +++++++++++++++++++++++ icechunk-python/src/sync_api.rs | 14 --------- icechunk-python/tests/test_credentials.py | 2 +- 7 files changed, 67 insertions(+), 53 deletions(-) create mode 100644 
icechunk-python/src/sync.rs delete mode 100644 icechunk-python/src/sync_api.rs diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 9ba7d3e0c..1e47770bf 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -32,7 +32,7 @@ use pyo3::{ types::{PyAny, PyAnyMethods, PyModule, PyType}, }; -use crate::errors::PyIcechunkStoreError; +use crate::{errors::PyIcechunkStoreError, sync::would_deadlock_current_loop}; #[pyclass(name = "S3StaticCredentials")] #[derive(Clone, Debug)] @@ -176,26 +176,6 @@ fn current_task_locals() -> Option { static FALLBACK_TASK_LOCALS: OnceLock> = OnceLock::new(); -fn object_id(py: Python<'_>, obj: &Bound<'_, PyAny>) -> Result { - let builtins = PyModule::import(py, "builtins")?; - builtins.getattr("id")?.call1((obj,))?.extract() -} - -fn would_deadlock_current_loop( - task_locals: &pyo3_async_runtimes::TaskLocals, -) -> Result { - Python::attach(|py| { - let running_loop = match pyo3_async_runtimes::get_running_loop(py) { - Ok(loop_ref) => loop_ref, - Err(err) if err.is_instance_of::(py) => return Ok(false), - Err(err) => return Err(err), - }; - - let target_loop = task_locals.event_loop(py); - Ok(object_id(py, &running_loop)? == object_id(py, &target_loop)?) 
- }) -} - fn create_fallback_task_locals() -> Result { let (tx, rx) = std::sync::mpsc::sync_channel::< Result, diff --git a/icechunk-python/src/lib.rs b/icechunk-python/src/lib.rs index c68b8205c..02c4a2d24 100644 --- a/icechunk-python/src/lib.rs +++ b/icechunk-python/src/lib.rs @@ -7,7 +7,7 @@ mod session; mod stats; mod store; mod streams; -mod sync_api; +mod sync; use std::env; diff --git a/icechunk-python/src/repository.rs b/icechunk-python/src/repository.rs index 28a36b9e5..8d3a93c46 100644 --- a/icechunk-python/src/repository.rs +++ b/icechunk-python/src/repository.rs @@ -47,7 +47,7 @@ use crate::{ session::PySession, stats::PyChunkStorageStats, streams::PyAsyncGenerator, - sync_api::ensure_not_running_event_loop, + sync::ensure_not_running_event_loop, }; /// Wrapper needed to implement pyo3 conversion classes @@ -1148,7 +1148,6 @@ impl PyRepository { bytes: Vec, ) -> PyResult { // This is a compute intensive task, we need to release the Gil - ensure_not_running_event_loop(py)?; py.detach(move || { let repository = Repository::from_bytes(bytes) .map_err(PyIcechunkStoreError::RepositoryError)?; @@ -1158,7 +1157,6 @@ impl PyRepository { fn as_bytes(&self, py: Python<'_>) -> PyResult> { // This is a compute intensive task, we need to release the Gil - ensure_not_running_event_loop(py)?; py.detach(move || { let bytes = self .0 @@ -1600,8 +1598,10 @@ impl PyRepository { pub(crate) fn list_manifest_files( &self, + py: Python<'_>, snapshot_id: &str, ) -> PyResult> { + ensure_not_running_event_loop(py)?; let snapshot_id = SnapshotId::try_from(snapshot_id).map_err(|_| { PyIcechunkStoreError::RepositoryError( RepositoryErrorKind::InvalidSnapshotId(snapshot_id.to_owned()).into(), @@ -2361,7 +2361,13 @@ impl PyRepository { } #[pyo3(signature = (snapshot_id, *, pretty = true))] - fn inspect_snapshot(&self, snapshot_id: String, pretty: bool) -> PyResult { + fn inspect_snapshot( + &self, + py: Python<'_>, + snapshot_id: String, + pretty: bool, + ) -> PyResult { + 
ensure_not_running_event_loop(py)?; let result = pyo3_async_runtimes::tokio::get_runtime() .block_on(async move { let lock = self.0.read().await; @@ -2397,11 +2403,14 @@ impl PyRepository { } #[getter] - fn spec_version(&self) -> u8 { - pyo3_async_runtimes::tokio::get_runtime().block_on(async move { - let repo = self.0.read().await; - repo.spec_version() - }) as u8 + fn spec_version(&self, py: Python<'_>) -> PyResult { + ensure_not_running_event_loop(py)?; + let spec_version = + pyo3_async_runtimes::tokio::get_runtime().block_on(async move { + let repo = self.0.read().await; + repo.spec_version() + }) as u8; + Ok(spec_version) } } diff --git a/icechunk-python/src/session.rs b/icechunk-python/src/session.rs index 21fed6202..6bce22726 100644 --- a/icechunk-python/src/session.rs +++ b/icechunk-python/src/session.rs @@ -21,7 +21,7 @@ use crate::{ repository::{PyDiff, PySnapshotProperties}, store::PyStore, streams::PyAsyncGenerator, - sync_api::ensure_not_running_event_loop, + sync::ensure_not_running_event_loop, }; #[pyclass] @@ -67,7 +67,6 @@ impl PySession { bytes: Vec, ) -> PyResult { // This is a compute intensive task, we need to release the Gil - ensure_not_running_event_loop(py)?; py.detach(move || { let session = Session::from_bytes(bytes).map_err(PyIcechunkStoreError::SessionError)?; @@ -81,7 +80,6 @@ impl PySession { fn as_bytes(&self, py: Python<'_>) -> PyIcechunkStoreResult> { // This is a compute intensive task, we need to release the Gil - ensure_not_running_event_loop(py)?; py.detach(move || { let bytes = self.0.blocking_read().as_bytes().map_err(PyIcechunkStoreError::from)?; @@ -135,7 +133,6 @@ impl PySession { pub fn discard_changes(&self, py: Python<'_>) -> PyResult<()> { // This is blocking function, we need to release the Gil - ensure_not_running_event_loop(py)?; py.detach(move || { self.0 .blocking_write() @@ -199,6 +196,7 @@ impl PySession { array_path: String, shift_chunk: Bound<'py, PyFunction>, ) -> PyResult<()> { + 
ensure_not_running_event_loop(py)?; let array_path = Path::new(array_path.as_str()) .map_err(|e| StoreError::from(StoreErrorKind::PathError(e))) .map_err(PyIcechunkStoreError::StoreError)?; @@ -232,7 +230,13 @@ impl PySession { }) } - pub fn shift_array(&mut self, array_path: String, offset: Vec) -> PyResult<()> { + pub fn shift_array( + &mut self, + py: Python<'_>, + array_path: String, + offset: Vec, + ) -> PyResult<()> { + ensure_not_running_event_loop(py)?; let array_path = Path::new(array_path.as_str()) .map_err(|e| StoreError::from(StoreErrorKind::PathError(e))) .map_err(PyIcechunkStoreError::StoreError)?; @@ -251,7 +255,6 @@ impl PySession { #[getter] pub fn store(&self, py: Python<'_>) -> PyResult { // This is blocking function, we need to release the Gil - ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); let conc = session.config().get_partial_values_concurrency(); @@ -265,7 +268,6 @@ impl PySession { #[getter] pub fn config(&self, py: Python<'_>) -> PyResult { // This is blocking function, we need to release the Gil - ensure_not_running_event_loop(py)?; py.detach(move || { let session = self.0.blocking_read(); let config = session.config().clone().into(); @@ -363,9 +365,11 @@ impl PySession { pub fn chunk_type( &self, + py: Python<'_>, array_path: String, coords: Vec, ) -> PyResult { + ensure_not_running_event_loop(py)?; let session = self.0.clone(); pyo3_async_runtimes::tokio::get_runtime() .block_on(Self::chunk_type_inner(session, array_path, coords)) diff --git a/icechunk-python/src/sync.rs b/icechunk-python/src/sync.rs new file mode 100644 index 000000000..d347ef83b --- /dev/null +++ b/icechunk-python/src/sync.rs @@ -0,0 +1,35 @@ +use pyo3::{ + Bound, PyErr, PyResult, Python, + exceptions::{PyRuntimeError, PyValueError}, + types::{PyAny, PyAnyMethods, PyModule}, +}; + +fn object_id(py: Python<'_>, obj: &Bound<'_, PyAny>) -> Result { + let builtins = PyModule::import(py, "builtins")?; + 
builtins.getattr("id")?.call1((obj,))?.extract() +} + +pub(crate) fn would_deadlock_current_loop( + task_locals: &pyo3_async_runtimes::TaskLocals, +) -> Result { + Python::attach(|py| { + let running_loop = match pyo3_async_runtimes::get_running_loop(py) { + Ok(loop_ref) => loop_ref, + Err(err) if err.is_instance_of::(py) => return Ok(false), + Err(err) => return Err(err), + }; + + let target_loop = task_locals.event_loop(py); + Ok(object_id(py, &running_loop)? == object_id(py, &target_loop)?) + }) +} + +pub(crate) fn ensure_not_running_event_loop(py: Python<'_>) -> PyResult<()> { + match pyo3_async_runtimes::get_running_loop(py) { + Ok(_) => Err(PyValueError::new_err( + "deadlock: synchronous API called from a running event loop thread; use the async API or run the sync call in a worker thread", + )), + Err(err) if err.is_instance_of::(py) => Ok(()), + Err(err) => Err(err), + } +} diff --git a/icechunk-python/src/sync_api.rs b/icechunk-python/src/sync_api.rs deleted file mode 100644 index d049b5b96..000000000 --- a/icechunk-python/src/sync_api.rs +++ /dev/null @@ -1,14 +0,0 @@ -use pyo3::{ - PyResult, Python, - exceptions::{PyRuntimeError, PyValueError}, -}; - -pub(crate) fn ensure_not_running_event_loop(py: Python<'_>) -> PyResult<()> { - match pyo3_async_runtimes::get_running_loop(py) { - Ok(_) => Err(PyValueError::new_err( - "deadlock: synchronous API called from a running event loop thread; use the async API or run the sync call in a worker thread", - )), - Err(err) if err.is_instance_of::(py) => Ok(()), - Err(err) => Err(err), - } -} diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index dee483c6f..821bf97db 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -567,7 +567,7 @@ def run_deadlock_check() -> None: t.start() t.join(timeout=1) assert not t.is_alive(), "Deadlocked: sync call blocked the event loop thread" - assert isinstance(caught_error, 
IcechunkError | ValueError) + assert isinstance(caught_error, ValueError) assert "deadlock" in str(caught_error) From c600664c97da78126de0fa73c1202dfe24579c41 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Mon, 16 Feb 2026 15:16:51 -0500 Subject: [PATCH 11/14] lint --- icechunk-python/src/config.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index 1e47770bf..a1b1af8cd 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -191,7 +191,13 @@ fn create_fallback_task_locals() -> Result Date: Mon, 16 Feb 2026 15:22:02 -0500 Subject: [PATCH 12/14] separate out asyncio bridge --- icechunk-python/src/asyncio_bridge.rs | 64 +++++++++++++++++++++++++ icechunk-python/src/config.rs | 68 +++------------------------ icechunk-python/src/lib.rs | 1 + 3 files changed, 72 insertions(+), 61 deletions(-) create mode 100644 icechunk-python/src/asyncio_bridge.rs diff --git a/icechunk-python/src/asyncio_bridge.rs b/icechunk-python/src/asyncio_bridge.rs new file mode 100644 index 000000000..99e596386 --- /dev/null +++ b/icechunk-python/src/asyncio_bridge.rs @@ -0,0 +1,64 @@ +use pyo3::{ + PyErr, PyResult, Python, + exceptions::PyRuntimeError, + types::{PyAnyMethods, PyModule}, +}; +use std::sync::OnceLock; + +pub(crate) fn current_task_locals() -> Option { + Python::attach(|py| pyo3_async_runtimes::tokio::get_current_locals(py).ok()) +} + +static FALLBACK_TASK_LOCALS: OnceLock> = + OnceLock::new(); + +fn create_fallback_task_locals() -> Result { + let (tx, rx) = std::sync::mpsc::sync_channel::< + Result, + >(1); + + std::thread::Builder::new() + .name("icechunk-python-asyncio-fallback".to_string()) + .spawn(move || { + let mut tx = Some(tx); + let run_result = Python::attach(|py| -> PyResult<()> { + let asyncio = PyModule::import(py, "asyncio")?; + let event_loop = asyncio.getattr("new_event_loop")?.call0()?; + 
asyncio.getattr("set_event_loop")?.call1((event_loop.clone(),))?; + + let locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone()); + if let Some(sender) = tx.take() { + let _ = sender.send(Ok(locals)); + } else { + return Err(PyRuntimeError::new_err( + "fallback asyncio loop sender was unexpectedly missing", + )); + } + + event_loop.call_method0("run_forever")?; + Ok(()) + }); + + if let Err(err) = run_result + && let Some(sender) = tx + { + let _ = sender.send(Err(err.to_string())); + } + }) + .map_err(|err| format!("failed to start fallback asyncio loop thread: {err}"))?; + + match rx.recv() { + Ok(Ok(locals)) => Ok(locals), + Ok(Err(err)) => Err(format!("failed to initialize fallback asyncio loop: {err}")), + Err(err) => Err(format!( + "fallback asyncio loop thread terminated before initialization: {err}" + )), + } +} + +pub(crate) fn fallback_task_locals() -> Result { + match FALLBACK_TASK_LOCALS.get_or_init(create_fallback_task_locals) { + Ok(locals) => Ok(locals.clone()), + Err(err) => Err(PyRuntimeError::new_err(err.clone())), + } +} diff --git a/icechunk-python/src/config.rs b/icechunk-python/src/config.rs index a1b1af8cd..929d4680e 100644 --- a/icechunk-python/src/config.rs +++ b/icechunk-python/src/config.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use chrono::{DateTime, Datelike, TimeDelta, Timelike, Utc}; use icechunk::storage::RetriesSettings; use itertools::Itertools; -use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::exceptions::PyValueError; use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; use std::{ @@ -11,7 +11,7 @@ use std::{ hash::DefaultHasher, num::{NonZeroU16, NonZeroU64}, path::PathBuf, - sync::{Arc, OnceLock}, + sync::Arc, }; use icechunk::{ @@ -32,7 +32,11 @@ use pyo3::{ types::{PyAny, PyAnyMethods, PyModule, PyType}, }; -use crate::{errors::PyIcechunkStoreError, sync::would_deadlock_current_loop}; +use crate::{ + asyncio_bridge::{current_task_locals, fallback_task_locals}, + 
errors::PyIcechunkStoreError, + sync::would_deadlock_current_loop, +}; #[pyclass(name = "S3StaticCredentials")] #[derive(Clone, Debug)] @@ -169,64 +173,6 @@ impl PythonCredentialsFetcher { } } -fn current_task_locals() -> Option { - Python::attach(|py| pyo3_async_runtimes::tokio::get_current_locals(py).ok()) -} - -static FALLBACK_TASK_LOCALS: OnceLock> = - OnceLock::new(); - -fn create_fallback_task_locals() -> Result { - let (tx, rx) = std::sync::mpsc::sync_channel::< - Result, - >(1); - - std::thread::Builder::new() - .name("icechunk-python-asyncio-fallback".to_string()) - .spawn(move || { - let mut tx = Some(tx); - let run_result = Python::attach(|py| -> PyResult<()> { - let asyncio = PyModule::import(py, "asyncio")?; - let event_loop = asyncio.getattr("new_event_loop")?.call0()?; - asyncio.getattr("set_event_loop")?.call1((event_loop.clone(),))?; - - let locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone()); - if let Some(sender) = tx.take() { - let _ = sender.send(Ok(locals)); - } else { - return Err(PyRuntimeError::new_err( - "fallback asyncio loop sender was unexpectedly missing", - )); - } - - event_loop.call_method0("run_forever")?; - Ok(()) - }); - - if let Err(err) = run_result - && let Some(sender) = tx - { - let _ = sender.send(Err(err.to_string())); - } - }) - .map_err(|err| format!("failed to start fallback asyncio loop thread: {err}"))?; - - match rx.recv() { - Ok(Ok(locals)) => Ok(locals), - Ok(Err(err)) => Err(format!("failed to initialize fallback asyncio loop: {err}")), - Err(err) => Err(format!( - "fallback asyncio loop thread terminated before initialization: {err}" - )), - } -} - -fn fallback_task_locals() -> Result { - match FALLBACK_TASK_LOCALS.get_or_init(create_fallback_task_locals) { - Ok(locals) => Ok(locals.clone()), - Err(err) => Err(PyRuntimeError::new_err(err.clone())), - } -} - enum PythonCallbackResult { Value(PyCred), Awaitable(Py), diff --git a/icechunk-python/src/lib.rs b/icechunk-python/src/lib.rs index 
02c4a2d24..e14aa3f0d 100644 --- a/icechunk-python/src/lib.rs +++ b/icechunk-python/src/lib.rs @@ -1,3 +1,4 @@ +mod asyncio_bridge; mod config; mod conflicts; mod errors; From 4c5225a5c454d5101a873218161503fd09b052e8 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Mon, 16 Feb 2026 15:25:43 -0500 Subject: [PATCH 13/14] mypy --- icechunk-python/tests/test_credentials.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/icechunk-python/tests/test_credentials.py b/icechunk-python/tests/test_credentials.py index 821bf97db..d5776bb96 100644 --- a/icechunk-python/tests/test_credentials.py +++ b/icechunk-python/tests/test_credentials.py @@ -3,6 +3,7 @@ import pickle import threading import time +from collections.abc import Iterator from datetime import UTC, datetime from pathlib import Path @@ -420,7 +421,7 @@ class BadException(Exception): @pytest.fixture(autouse=False) -def reset_cred_tracker(): +def reset_cred_tracker() -> Iterator[list[int]]: _cred_refresh_loop_ids.clear() yield _cred_refresh_loop_ids _cred_refresh_loop_ids.clear() @@ -536,7 +537,7 @@ def sync_in_thread() -> None: asyncio.run(run()) -def test_async_cred_refresh_graceful_deadlock(): +def test_async_cred_refresh_graceful_deadlock() -> None: # Deadlock scenario: sync list_branches called directly on the event loop # thread where the credential callback would also need to run. Use a thread # with a timeout so the test doesn't hang forever if it actually deadlocks. 
@@ -599,6 +600,6 @@ def on_another_thread(r: Repository) -> None: t.join() assert len(_cred_refresh_loop_ids) > 4 assert len(set(_cred_refresh_loop_ids)) == 1 - assert len(set(_cred_refresh_loop_ids)) == 1, ( - "All credential refresh calls should be on the same loop even if there isn't an event loop at the start" - ) + assert ( + len(set(_cred_refresh_loop_ids)) == 1 + ), "All credential refresh calls should be on the same loop even if there isn't an event loop at the start" From a99cdde20a59f3afeda49c8098dca915d904b6ec Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Mon, 16 Feb 2026 16:41:35 -0500 Subject: [PATCH 14/14] Use async when we should from tests --- icechunk-python/src/sync.rs | 2 +- icechunk-python/tests/conftest.py | 7 + icechunk-python/tests/test_can_read_old.py | 101 ++++++++------ icechunk-python/tests/test_concurrency.py | 10 +- .../tests/test_distributed_writers.py | 23 +-- icechunk-python/tests/test_gc.py | 23 +-- icechunk-python/tests/test_inspect.py | 3 +- icechunk-python/tests/test_move.py | 20 +-- icechunk-python/tests/test_regressions.py | 34 ++--- icechunk-python/tests/test_session.py | 36 ++--- icechunk-python/tests/test_shift.py | 35 ++--- icechunk-python/tests/test_store.py | 59 ++++---- icechunk-python/tests/test_timetravel.py | 131 +++++++++--------- icechunk-python/tests/test_virtual_ref.py | 12 +- .../tests/test_zarr/test_store/test_core.py | 6 +- .../test_store/test_icechunk_store.py | 22 +-- 16 files changed, 282 insertions(+), 242 deletions(-) diff --git a/icechunk-python/src/sync.rs b/icechunk-python/src/sync.rs index d347ef83b..0c79506bf 100644 --- a/icechunk-python/src/sync.rs +++ b/icechunk-python/src/sync.rs @@ -27,7 +27,7 @@ pub(crate) fn would_deadlock_current_loop( pub(crate) fn ensure_not_running_event_loop(py: Python<'_>) -> PyResult<()> { match pyo3_async_runtimes::get_running_loop(py) { Ok(_) => Err(PyValueError::new_err( - "deadlock: synchronous API called from a running event loop thread; use the async API or 
run the sync call in a worker thread", + "synchronous API called from a running event loop thread (deadlock risk); this usage is disallowed. Use the async API or run the synchronous call in a worker thread", )), Err(err) if err.is_instance_of::(py) => Ok(()), Err(err) => Err(err), diff --git a/icechunk-python/tests/conftest.py b/icechunk-python/tests/conftest.py index 87dbbf4bc..cd83c6712 100644 --- a/icechunk-python/tests/conftest.py +++ b/icechunk-python/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio from typing import Literal, cast import boto3 @@ -22,6 +23,12 @@ def parse_repo( ) +async def parse_repo_async( + store: Literal["local", "memory"], path: str, spec_version: int | None +) -> Repository: + return await asyncio.to_thread(parse_repo, store, path, spec_version) + + @pytest.fixture(scope="function") def repo( request: pytest.FixtureRequest, tmpdir: str, any_spec_version: int | None diff --git a/icechunk-python/tests/test_can_read_old.py b/icechunk-python/tests/test_can_read_old.py index f0f797889..4b36edba3 100644 --- a/icechunk-python/tests/test_can_read_old.py +++ b/icechunk-python/tests/test_can_read_old.py @@ -11,6 +11,7 @@ file as a python script: `python ./tests/test_can_read_old.py`. 
""" +import asyncio import shutil from datetime import UTC, datetime from typing import Any, cast @@ -86,8 +87,10 @@ async def write_a_split_repo(path: str) -> None: ) print(f"Writing repository to {store_path}") - repo = mk_repo(create=True, store_path=store_path, config=config) - session = repo.writable_session("main") + repo = await asyncio.to_thread( + mk_repo, create=True, store_path=store_path, config=config + ) + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) @@ -115,47 +118,49 @@ async def write_a_split_repo(path: str) -> None: fill_value=8, attributes={"this": "is a nice array", "icechunk": 1, "size": 42.0}, ) - session.commit("empty structure") + await session.commit_async("empty structure") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") big_chunks = zarr.open_array(session.store, path="/group1/split", mode="a") small_chunks = zarr.open_array(session.store, path="/group1/small_chunks", mode="a") big_chunks[:] = 120 small_chunks[:] = 0 - session.commit("write data") + await session.commit_async("write data") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") big_chunks = zarr.open_array(session.store, path="group1/split", mode="a") small_chunks = zarr.open_array(session.store, path="group1/small_chunks", mode="a") big_chunks[:] = 12 small_chunks[:] = 1 - session.commit("write data again") + await session.commit_async("write data again") ### new config config = ic.RepositoryConfig.default() config.inline_chunk_threshold_bytes = 12 config.manifest = ic.ManifestConfig(splitting=UPDATED_SPLITTING_CONFIG) - repo = mk_repo(create=False, store_path=store_path, config=config) - repo.save_config() - session = repo.writable_session("main") + repo = await asyncio.to_thread( + mk_repo, create=False, store_path=store_path, config=config + ) + await repo.save_config_async() + session = await 
repo.writable_session_async("main") big_chunks = zarr.open_array(session.store, path="group1/split", mode="a") small_chunks = zarr.open_array(session.store, path="group1/small_chunks", mode="a") big_chunks[:] = 14 small_chunks[:] = 3 - session.commit("write data again with more splits") + await session.commit_async("write data again with more splits") async def do_icechunk_can_read_old_repo_with_manifest_splitting(path: str) -> None: - repo = mk_repo(create=False, store_path=path) - ancestry = list(repo.ancestry(branch="main"))[::-1] + repo = await asyncio.to_thread(mk_repo, create=False, store_path=path) + ancestry = (await asyncio.to_thread(lambda: list(repo.ancestry(branch="main"))))[::-1] init_snapshot = ancestry[1] assert init_snapshot.message == "empty structure" - assert len(repo.list_manifest_files(init_snapshot.id)) == 0 + assert len(await repo.list_manifest_files_async(init_snapshot.id)) == 0 snapshot = ancestry[2] assert snapshot.message == "write data" - assert len(repo.list_manifest_files(snapshot.id)) == 9 + assert len(await repo.list_manifest_files_async(snapshot.id)) == 9 snapshot = ancestry[3] assert snapshot.message == "write data again" @@ -180,8 +185,8 @@ async def write_a_test_repo(path: str) -> None: """ print(f"Writing repository to {path}") - repo = mk_repo(create=True, store_path=path) - session = repo.writable_session("main") + repo = await asyncio.to_thread(mk_repo, create=True, store_path=path) + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) @@ -208,9 +213,9 @@ async def write_a_test_repo(path: str) -> None: fill_value=8, attributes={"this": "is a nice array", "icechunk": 1, "size": 42.0}, ) - session.commit("empty structure") + await session.commit_async("empty structure") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) big_chunks = cast("zarr.Array[Any]", 
root["group1/big_chunks"]) @@ -218,8 +223,8 @@ async def write_a_test_repo(path: str) -> None: big_chunks[:] = 42.0 small_chunks[:] = 84 - snapshot = session.commit("fill data") - session = repo.writable_session("main") + snapshot = await session.commit_async("fill data") + session = await repo.writable_session_async("main") store = session.store # We are going to write this chunk to storage as a virtual chunk @@ -236,22 +241,22 @@ async def write_a_test_repo(path: str) -> None: length=virtual_chunk_data_size, checksum=datetime(9999, 12, 31, tzinfo=UTC), ) - snapshot = session.commit("set virtual chunk") - session = repo.writable_session("main") + snapshot = await session.commit_async("set virtual chunk") + session = await repo.writable_session_async("main") store = session.store - repo.create_branch("my-branch", snapshot_id=snapshot) - session = repo.writable_session("my-branch") + await repo.create_branch_async("my-branch", snapshot_id=snapshot) + session = await repo.writable_session_async("my-branch") store = session.store await store.delete("group1/small_chunks/c/4") - snap4 = session.commit("delete a chunk") + snap4 = await session.commit_async("delete a chunk") - repo.create_tag("it works!", snapshot_id=snap4) - repo.create_tag("deleted", snapshot_id=snap4) - repo.delete_tag("deleted") + await repo.create_tag_async("it works!", snapshot_id=snap4) + await repo.create_tag_async("deleted", snapshot_id=snap4) + await repo.delete_tag_async("deleted") - session = repo.writable_session("my-branch") + session = await repo.writable_session_async("my-branch") store = session.store root = zarr.open_group(store=store) @@ -275,9 +280,9 @@ async def write_a_test_repo(path: str) -> None: fill_value=float("nan"), attributes={"this": "is a nice array", "icechunk": 1, "size": 42.0}, ) - snap5 = session.commit("some more structure") + snap5 = await session.commit_async("some more structure") - repo.create_tag("it also works!", snapshot_id=snap5) + await 
repo.create_tag_async("it also works!", snapshot_id=snap5) store.close() @@ -286,8 +291,8 @@ async def do_icechunk_can_read_old_repo(path: str) -> None: # we import here so it works when the script is ran by pytest from tests.conftest import write_chunks_to_minio - repo = mk_repo(create=False, store_path=path) - main_snapshot = repo.lookup_branch("main") + repo = await asyncio.to_thread(mk_repo, create=False, store_path=path) + main_snapshot = await repo.lookup_branch_async("main") expected_main_history = [ "set virtual chunk", @@ -296,7 +301,10 @@ async def do_icechunk_can_read_old_repo(path: str) -> None: "Repository initialized", ] assert [ - p.message for p in repo.ancestry(snapshot_id=main_snapshot) + p.message + for p in await asyncio.to_thread( + lambda: list(repo.ancestry(snapshot_id=main_snapshot)) + ) ] == expected_main_history expected_branch_history = [ @@ -305,21 +313,26 @@ async def do_icechunk_can_read_old_repo(path: str) -> None: ] + expected_main_history assert [ - p.message for p in repo.ancestry(branch="my-branch") + p.message + for p in await asyncio.to_thread(lambda: list(repo.ancestry(branch="my-branch"))) ] == expected_branch_history assert [ - p.message for p in repo.ancestry(tag="it also works!") + p.message + for p in await asyncio.to_thread( + lambda: list(repo.ancestry(tag="it also works!")) + ) ] == expected_branch_history - assert [p.message for p in repo.ancestry(tag="it works!")] == expected_branch_history[ - 1: - ] + assert [ + p.message + for p in await asyncio.to_thread(lambda: list(repo.ancestry(tag="it works!"))) + ] == expected_branch_history[1:] with pytest.raises(ic.IcechunkError, match="ref not found"): - repo.readonly_session(tag="deleted") + await repo.readonly_session_async(tag="deleted") - session = repo.writable_session("my-branch") + session = await repo.writable_session_async("my-branch") store = session.store assert sorted([p async for p in store.list_dir("")]) == [ "group1", @@ -374,8 +387,8 @@ async def 
do_icechunk_can_read_old_repo(path: str) -> None: big_chunks = cast("zarr.Array[Any]", root["group1/big_chunks"]) assert_array_equal(big_chunks[:], 42.0) - parents = list(repo.ancestry(branch="main")) - diff = repo.diff(to_branch="main", from_snapshot_id=parents[-2].id) + parents = await asyncio.to_thread(lambda: list(repo.ancestry(branch="main"))) + diff = await repo.diff_async(to_branch="main", from_snapshot_id=parents[-2].id) assert diff.new_groups == set() assert diff.new_arrays == set() assert set(diff.updated_chunks.keys()) == { diff --git a/icechunk-python/tests/test_concurrency.py b/icechunk-python/tests/test_concurrency.py index ca0261f6c..d2c4f1de4 100644 --- a/icechunk-python/tests/test_concurrency.py +++ b/icechunk-python/tests/test_concurrency.py @@ -52,12 +52,12 @@ async def list_store(store: icechunk.IcechunkStore, barrier: asyncio.Barrier) -> async def test_concurrency(any_spec_version: int | None) -> None: - repo = icechunk.Repository.open_or_create( + repo = await icechunk.Repository.open_or_create_async( storage=icechunk.in_memory_storage(), create_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) @@ -85,7 +85,7 @@ async def test_concurrency(any_spec_version: int | None) -> None: all_coords = {coords async for coords in session.chunk_coordinates("/array")} assert all_coords == {(x, y) for x in range(N) for y in range(N - 1)} - _res = session.commit("commit") + _res = await session.commit_async("commit") assert isinstance(group["array"], zarr.Array) array = group["array"] @@ -150,14 +150,14 @@ async def test_thread_concurrency(any_spec_version: int | None) -> None: ) # Open the store - repo = icechunk.Repository.create( + repo = await icechunk.Repository.create_async( storage=storage, config=config, authorize_virtual_chunk_access=credentials, spec_version=any_spec_version, ) - session = 
repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) diff --git a/icechunk-python/tests/test_distributed_writers.py b/icechunk-python/tests/test_distributed_writers.py index a856e1258..3f8191e64 100644 --- a/icechunk-python/tests/test_distributed_writers.py +++ b/icechunk-python/tests/test_distributed_writers.py @@ -1,3 +1,4 @@ +import asyncio import time import warnings from typing import Any, cast @@ -70,8 +71,8 @@ async def test_distributed_writers( does a distributed commit. When done, we open the store again and verify we can write everything we have written. """ - repo = mk_repo(any_spec_version, use_object_store) - session = repo.writable_session(branch="main") + repo = await asyncio.to_thread(mk_repo, any_spec_version, use_object_store) + session = await repo.writable_session_async(branch="main") store = session.store shape = (CHUNKS_PER_DIM * CHUNK_DIM_SIZE,) * 2 @@ -86,22 +87,22 @@ async def test_distributed_writers( dtype="f8", fill_value=float("nan"), ) - first_snap = session.commit("array created") + first_snap = await session.commit_async("array created") - def do_writes(branch_name: str) -> None: - repo.create_branch(branch_name, first_snap) - session = repo.writable_session(branch=branch_name) + async def do_writes(branch_name: str) -> None: + await repo.create_branch_async(branch_name, first_snap) + session = await repo.writable_session_async(branch=branch_name) fork = session.fork() group = zarr.open_group(store=fork.store) zarray = cast("zarr.Array[Any]", group["array"]) merged_session = store_dask(sources=[dask_array], targets=[zarray]) - session.merge(merged_session) - commit_res = session.commit("distributed commit") + await session.merge_async(merged_session) + commit_res = await session.commit_async("distributed commit") assert commit_res async def verify(branch_name: str) -> None: # Lets open a new store to verify the results - 
readonly_session = repo.readonly_session(branch=branch_name) + readonly_session = await repo.readonly_session_async(branch=branch_name) store = readonly_session.store all_keys = [key async for key in store.list_prefix("/")] assert ( @@ -116,9 +117,9 @@ async def verify(branch_name: str) -> None: assert_eq(roundtripped, dask_array) # type: ignore [no-untyped-call] with Client(dashboard_address=":0"): # type: ignore[no-untyped-call] - do_writes("with-processes") + await do_writes("with-processes") await verify("with-processes") with dask.config.set(scheduler="threads"): - do_writes("with-threads") + await do_writes("with-threads") await verify("with-threads") diff --git a/icechunk-python/tests/test_gc.py b/icechunk-python/tests/test_gc.py index 46acf5b70..1b6a0b153 100644 --- a/icechunk-python/tests/test_gc.py +++ b/icechunk-python/tests/test_gc.py @@ -1,3 +1,4 @@ +import asyncio import time from datetime import UTC, datetime from typing import Any, cast @@ -31,9 +32,9 @@ def mk_repo(spec_version: int | None) -> tuple[str, ic.Repository]: @pytest.mark.filterwarnings("ignore:datetime.datetime.utcnow") @pytest.mark.parametrize("use_async", [True, False]) async def test_expire_and_gc(use_async: bool, any_spec_version: int | None) -> None: - prefix, repo = mk_repo(any_spec_version) + prefix, repo = await asyncio.to_thread(mk_repo, any_spec_version) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) @@ -44,24 +45,24 @@ async def test_expire_and_gc(use_async: bool, any_spec_version: int | None) -> N dtype="i4", fill_value=-1, ) - session.commit("array created") + await session.commit_async("array created") for i in range(20): - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.open_group(store=store) array = cast("zarr.core.array.Array[Any]", group["array"]) array[i] = i - 
session.commit(f"written coord {i}") + await session.commit_async(f"written coord {i}") old = datetime.now(UTC) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.open_group(store=store) array = cast("zarr.core.array.Array[Any]", group["array"]) array[999] = 0 - session.commit("written coord 999") + await session.commit_async("written coord 999") client = get_minio_client() @@ -105,7 +106,7 @@ async def test_expire_and_gc(use_async: bool, any_spec_version: int | None) -> N if use_async: expired_snapshots = await repo.expire_snapshots_async(old) else: - expired_snapshots = repo.expire_snapshots(old) + expired_snapshots = await asyncio.to_thread(repo.expire_snapshots, old) # empty array + 20 old versions assert len(expired_snapshots) == 21 @@ -135,7 +136,7 @@ def space_used() -> int: if use_async: gc_result = await repo.garbage_collect_async(old, dry_run=True) else: - gc_result = repo.garbage_collect(old, dry_run=True) + gc_result = await asyncio.to_thread(repo.garbage_collect, old, dry_run=True) space_after = space_used() assert space_before == space_after @@ -154,7 +155,7 @@ def space_used() -> int: if use_async: gc_result = await repo.garbage_collect_async(old) else: - gc_result = repo.garbage_collect(old) + gc_result = await asyncio.to_thread(repo.garbage_collect, old) space_after = space_used() @@ -204,7 +205,7 @@ def space_used() -> int: ) # we can still read the array - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") store = session.store group = zarr.open_group(store=store, mode="r") array = cast("zarr.core.array.Array[Any]", group["array"]) diff --git a/icechunk-python/tests/test_inspect.py b/icechunk-python/tests/test_inspect.py index a04ebc325..d4ea34a6d 100644 --- a/icechunk-python/tests/test_inspect.py +++ b/icechunk-python/tests/test_inspect.py @@ -1,3 +1,4 @@ +import asyncio import json import icechunk as ic @@ 
-23,7 +24,7 @@ async def test_inspect_snapshot_async() -> None: repo = await ic.Repository.open_async( storage=ic.local_filesystem_storage("./tests/data/split-repo-v2") ) - snap = next(repo.ancestry(branch="main")).id + snap = (await asyncio.to_thread(lambda: next(repo.ancestry(branch="main")))).id pretty_str = await repo.inspect_snapshot_async(snap, pretty=True) non_pretty_str = await repo.inspect_snapshot_async(snap, pretty=False) diff --git a/icechunk-python/tests/test_move.py b/icechunk-python/tests/test_move.py index a536eace5..045a79847 100644 --- a/icechunk-python/tests/test_move.py +++ b/icechunk-python/tests/test_move.py @@ -1,3 +1,5 @@ +import asyncio + import numpy.testing import pytest @@ -6,10 +8,10 @@ async def test_basic_move() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store, overwrite=True) group = root.create_group("my/old/path", overwrite=True) @@ -24,11 +26,11 @@ async def test_basic_move() -> None: "my/old/path/array/zarr.json", ] ) - session.commit("create array") + await session.commit_async("create array") - session = repo.rearrange_session("main") + session = await repo.rearrange_session_async("main") store = session.store - session.move("/my/old", "/my/new") + await session.move_async("/my/old", "/my/new") all_keys = sorted([k async for k in store.list()]) assert all_keys == sorted( [ @@ -39,16 +41,16 @@ async def test_basic_move() -> None: "my/new/path/array/zarr.json", ] ) - session.commit("directory renamed") + await session.commit_async("directory renamed") - session = repo.readonly_session("main") + session = await repo.readonly_session_async("main") store = session.store group = zarr.open_group(store=store, mode="r") array = group["my/new/path/array"] numpy.testing.assert_array_equal(array, 42) - a, b, *_ = 
repo.ancestry(branch="main") - diff = repo.diff(from_snapshot_id=b.id, to_snapshot_id=a.id) + a, b, *_ = await asyncio.to_thread(lambda: list(repo.ancestry(branch="main"))) + diff = await repo.diff_async(from_snapshot_id=b.id, to_snapshot_id=a.id) assert diff.moved_nodes == [("/my/old", "/my/new")] assert ( repr(diff) diff --git a/icechunk-python/tests/test_regressions.py b/icechunk-python/tests/test_regressions.py index 494556fcc..33027e8bd 100644 --- a/icechunk-python/tests/test_regressions.py +++ b/icechunk-python/tests/test_regressions.py @@ -53,13 +53,13 @@ async def test_issue_418(any_spec_version: int | None) -> None: } ) - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), config=config, authorize_virtual_chunk_access=credentials, spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.Group.from_store(store=store, zarr_format=3) @@ -82,9 +82,9 @@ async def test_issue_418(any_spec_version: int | None) -> None: assert (await store._store.get("time/c/0")) == b"firs" # codespell:ignore firs assert (await store._store.get("time/c/1")) == b"econ" - session.commit("Initial commit") + await session.commit_async("Initial commit") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.Group.open(store=store) @@ -104,7 +104,7 @@ async def test_issue_418(any_spec_version: int | None) -> None: assert (await store._store.get("time/c/2")) == b"thir" # commit - session.commit("Append virtual ref") + await session.commit_async("Append virtual ref") assert (await store._store.get("lon/c/0")) == b"fift" assert (await store._store.get("time/c/0")) == b"firs" # codespell:ignore firs @@ -114,11 +114,11 @@ async def test_issue_418(any_spec_version: int | None) -> None: async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: # This 
regression appeared during the change to manifest per array - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.Group.from_store(store=store, zarr_format=3) @@ -127,9 +127,9 @@ async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: array1[:] = 42 assert array1[0] == 42 print("about to commit 1") - session.commit("create array 1") + await session.commit_async("create array 1") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store, zarr_format=3) # array1 = root.require_array(name="array1", shape=((2,)), chunks=((1,)), dtype="i4") @@ -137,9 +137,9 @@ async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: # assert array1[0] == 42 array2[:] = 84 print("about to commit 2") - session.commit("create array 2") + await session.commit_async("create array 2") - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") store = session.store root = zarr.Group.open(store=store, zarr_format=3) array1 = root.require_array(name="array1", shape=((2,)), chunks=((1,)), dtype="i4") @@ -149,11 +149,11 @@ async def test_read_chunks_from_old_array(any_spec_version: int | None) -> None: async def test_tag_with_open_session(any_spec_version: int | None) -> None: """This is an issue found by hypothesis""" - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store await store.set( @@ -168,8 +168,8 @@ async def test_tag_with_open_session(any_spec_version: int | None) -> None: b'{\n "shape": [\n 1\n ],\n 
"data_type": "bool",\n "chunk_grid": {\n "name": "regular",\n "configuration": {\n "chunk_shape": [\n 1\n ]\n }\n },\n "chunk_key_encoding": {\n "name": "default",\n "configuration": {\n "separator": "/"\n }\n },\n "fill_value": false,\n "codecs": [\n {\n "name": "bytes"\n }\n ],\n "attributes": {},\n "zarr_format": 3,\n "node_type": "array",\n "storage_transformers": []\n}' ), ) - session.commit("") - session = repo.writable_session("main") + await session.commit_async("") + session = await repo.writable_session_async("main") store = session.store await store.set( @@ -193,8 +193,8 @@ async def test_tag_with_open_session(any_spec_version: int | None) -> None: ), ) - session.commit("") - session = repo.writable_session("main") + await session.commit_async("") + session = await repo.writable_session_async("main") store = session.store async for k in store.list_prefix(""): diff --git a/icechunk-python/tests/test_session.py b/icechunk-python/tests/test_session.py index a4fc03ced..bf83c5d15 100644 --- a/icechunk-python/tests/test_session.py +++ b/icechunk-python/tests/test_session.py @@ -24,11 +24,11 @@ @pytest.mark.parametrize("use_async", [True, False]) async def test_session_fork(use_async: bool, any_spec_version: int | None) -> None: with tempfile.TemporaryDirectory() as tmpdir: - repo = Repository.create( + repo = await Repository.create_async( local_filesystem_storage(tmpdir), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") zarr.group(session.store) assert session.has_uncommitted_changes @@ -38,7 +38,7 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No if use_async: await session.commit_async("init") else: - session.commit("init") + await session.commit_async("init") # forking a read-only session with pytest.raises(ValueError): @@ -48,29 +48,29 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No 
pickle.loads(pickle.dumps(session.fork())) pickle.loads(pickle.dumps(session)) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") fork = pickle.loads(pickle.dumps(session.fork())) zarr.create_group(fork.store, path="/foo") assert not session.has_uncommitted_changes assert fork.has_uncommitted_changes with pytest.warns(UserWarning): with pytest.raises(IcechunkError, match="cannot commit"): - session.commit("foo") + await session.commit_async("foo") if use_async: await session.merge_async(fork) await session.commit_async("foo") else: - session.merge(fork) - session.commit("foo") + await session.merge_async(fork) + await session.commit_async("foo") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") fork1 = pickle.loads(pickle.dumps(session.fork())) fork2 = pickle.loads(pickle.dumps(session.fork())) zarr.create_group(fork1.store, path="/foo1") zarr.create_group(fork2.store, path="/foo2") with pytest.raises(TypeError, match="Cannot commit a fork"): - fork1.commit("foo") + await fork1.commit_async("foo") fork1 = pickle.loads(pickle.dumps(fork1)) fork2 = pickle.loads(pickle.dumps(fork2)) @@ -83,8 +83,8 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No await session.merge_async(fork1, fork2) await session.commit_async("all done") else: - session.merge(fork1, fork2) - session.commit("all done") + await session.merge_async(fork1, fork2) + await session.commit_async("all done") groups = set( name for name, _ in zarr.open_group(session.store, mode="r").groups() @@ -92,7 +92,7 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No assert groups == {"foo", "foo1", "foo2"} # forking a forked session may be useful - session = repo.writable_session("main") + session = await repo.writable_session_async("main") fork1 = pickle.loads(pickle.dumps(session.fork())) fork2 = pickle.loads(pickle.dumps(fork1.fork())) 
zarr.create_group(fork1.store, path="/foo3") @@ -102,7 +102,7 @@ async def test_session_fork(use_async: bool, any_spec_version: int | None) -> No fork1 = pickle.loads(pickle.dumps(fork1)) fork2 = pickle.loads(pickle.dumps(fork2)) - session.merge(fork1, fork2) + await session.merge_async(fork1, fork2) groups = set( name for name, _ in zarr.open_group(session.store, mode="r").groups() @@ -124,13 +124,13 @@ async def test_chunk_type( container = VirtualChunkContainer("file:///foo/", store_config) config.set_virtual_chunk_container(container) - repo = Repository.create( + repo = await Repository.create_async( storage=in_memory_storage(), config=config, spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) air_temp = group.create_array("air_temp", shape=(1, 4), chunks=(1, 1), dtype="i4") @@ -154,9 +154,9 @@ async def test_chunk_type( air_temp[0, 2] = 42 assert air_temp[0, 2] == 42 - assert session.chunk_type("/air_temp", [0, 0]) == ChunkType.VIRTUAL - assert session.chunk_type("/air_temp", [0, 2]) == chunk_type - assert session.chunk_type("/air_temp", [0, 3]) == ChunkType.UNINITIALIZED + assert await session.chunk_type_async("/air_temp", [0, 0]) == ChunkType.VIRTUAL + assert await session.chunk_type_async("/air_temp", [0, 2]) == chunk_type + assert await session.chunk_type_async("/air_temp", [0, 3]) == ChunkType.UNINITIALIZED assert await session.chunk_type_async("/air_temp", [0, 0]) == ChunkType.VIRTUAL assert await session.chunk_type_async("/air_temp", [0, 2]) == chunk_type diff --git a/icechunk-python/tests/test_shift.py b/icechunk-python/tests/test_shift.py index 1f9b09bb1..017472226 100644 --- a/icechunk-python/tests/test_shift.py +++ b/icechunk-python/tests/test_shift.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import Iterable from typing import Any, cast @@ -8,18 +9,18 @@ async def test_shift_using_function() 
-> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) array = root.create_array( "array", shape=(50,), chunks=(2,), dtype="i4", fill_value=42 ) array[:] = np.arange(50) - session.commit("create array") + await session.commit_async("create array") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=False) array = cast("zarr.Array[Any]", root["array"]) assert array[0] == 0 @@ -29,27 +30,27 @@ def reindex(idx: Iterable[int]) -> Iterable[int] | None: idx = list(idx) return [idx[0] - 4] if idx[0] >= 4 else None - session.reindex_array("/array", reindex) + await asyncio.to_thread(lambda: session.reindex_array("/array", reindex)) # we moved 4 chunks to the left, that's 8 array elements np.testing.assert_equal(array[0:42], np.arange(8, 50)) np.testing.assert_equal(array[42:], np.arange(42, 50)) async def test_shift_using_shift_by_offset() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) array = root.create_array( "array", shape=(50,), chunks=(2,), dtype="i4", fill_value=42 ) array[:] = np.arange(50) - session.commit("create array") + await session.commit_async("create array") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=False) - session.shift_array("/array", (-4,)) + await asyncio.to_thread(lambda: session.shift_array("/array", (-4,))) array = cast("zarr.Array[Any]", root["array"]) # we moved 4 chunks to the left, that's 8 array elements 
np.testing.assert_equal(array[0:42], np.arange(8, 50)) @@ -57,30 +58,30 @@ async def test_shift_using_shift_by_offset() -> None: async def test_resize_and_shift_right() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) array = root.create_array( "array", shape=(50,), chunks=(2,), dtype="i4", fill_value=42 ) array[:] = np.arange(50) - session.commit("create array") + await session.commit_async("create array") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=False) array = cast("zarr.Array[Any]", root["array"]) array.resize((100,)) assert array.shape == (100,) - session.shift_array("/array", (4,)) + await asyncio.to_thread(lambda: session.shift_array("/array", (4,))) np.testing.assert_equal(array[8:58], np.arange(50)) np.testing.assert_equal(array[0:8], np.arange(8)) assert np.all(array[58:] == 42) - session.commit("shifted") + await session.commit_async("shifted") # test still valid after commit - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") root = zarr.open_group(store=session.store, mode="r") array = cast("zarr.Array[Any]", root["array"]) assert array.shape == (100,) diff --git a/icechunk-python/tests/test_store.py b/icechunk-python/tests/test_store.py index 475ab86e6..bc6641aed 100644 --- a/icechunk-python/tests/test_store.py +++ b/icechunk-python/tests/test_store.py @@ -1,24 +1,25 @@ +import asyncio import json import numpy as np import icechunk as ic import zarr -from tests.conftest import parse_repo +from tests.conftest import parse_repo, parse_repo_async from zarr.core.buffer import cpu, default_buffer_prototype rng = np.random.default_rng(seed=12345) async def 
test_store_clear_metadata_list(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.writable_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.writable_session_async("main") store = session.store zarr.group(store=store) - session.commit("created node /") + await session.commit_async("created node /") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store await store.clear() zarr.group(store=store) @@ -26,8 +27,8 @@ async def test_store_clear_metadata_list(any_spec_version: int | None) -> None: async def test_store_clear_chunk_list(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.writable_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.writable_session_async("main") store = session.store array_kwargs = dict( @@ -35,9 +36,9 @@ async def test_store_clear_chunk_list(any_spec_version: int | None) -> None: ) group = zarr.group(store=store) group.create_array(**array_kwargs) # type: ignore[arg-type] - session.commit("created node /") + await session.commit_async("created node /") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store await store.clear() @@ -55,8 +56,8 @@ async def test_store_clear_chunk_list(any_spec_version: int | None) -> None: async def test_support_dimension_names_null(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.writable_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.writable_session_async("main") store = session.store root = zarr.group(store=store) @@ -78,12 +79,12 @@ def test_doesnt_support_consolidated_metadata(any_spec_version: int | None) -> N async def 
test_with_readonly(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - session = repo.readonly_session("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + session = await repo.readonly_session_async("main") store = session.store assert store.read_only - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store writer = store.with_read_only(read_only=False) assert not writer._is_open @@ -99,29 +100,37 @@ async def test_with_readonly(any_spec_version: int | None) -> None: async def test_transaction(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - cid1 = repo.lookup_branch("main") + repo = await parse_repo_async("memory", "test", any_spec_version) + cid1 = await repo.lookup_branch_async("main") + # TODO: test metadata, rebase_with, and rebase_tries kwargs - with repo.transaction("main", message="initialize group") as store: - assert not store.read_only - root = zarr.group(store=store) - root.attrs["foo"] = "bar" - cid2 = repo.lookup_branch("main") + def _run_transaction() -> None: + with repo.transaction("main", message="initialize group") as store: + assert not store.read_only + root = zarr.group(store=store) + root.attrs["foo"] = "bar" + + await asyncio.to_thread(_run_transaction) + cid2 = await repo.lookup_branch_async("main") assert cid1 != cid2, "Transaction did not commit changes" async def test_transaction_failed_no_commit(any_spec_version: int | None) -> None: - repo = parse_repo("memory", "test", any_spec_version) - cid1 = repo.lookup_branch("main") - try: + repo = await parse_repo_async("memory", "test", any_spec_version) + cid1 = await repo.lookup_branch_async("main") + + def _run_failing_transaction() -> None: with repo.transaction("main", message="initialize group") as store: assert not store.read_only root = zarr.group(store=store) root.attrs["foo"] = "bar" raise 
RuntimeError("Simulating an error to prevent commit") + + try: + await asyncio.to_thread(_run_failing_transaction) except RuntimeError: pass - cid2 = repo.lookup_branch("main") + cid2 = await repo.lookup_branch_async("main") assert cid1 == cid2, "Transaction committed changes despite error" diff --git a/icechunk-python/tests/test_timetravel.py b/icechunk-python/tests/test_timetravel.py index 0106c9a76..22db515ee 100644 --- a/icechunk-python/tests/test_timetravel.py +++ b/icechunk-python/tests/test_timetravel.py @@ -14,7 +14,11 @@ async def async_ancestry( repo: ic.Repository, **kwargs: str | None ) -> list[ic.SnapshotInfo]: - return [parent async for parent in repo.async_ancestry(**kwargs)] + return await asyncio.to_thread(lambda: list(repo.ancestry(**kwargs))) + + +async def first_ancestry(repo: ic.Repository, **kwargs: str | None) -> ic.SnapshotInfo: + return await asyncio.to_thread(lambda: next(repo.ancestry(**kwargs))) @pytest.mark.parametrize( @@ -235,41 +239,41 @@ def test_timetravel(using_flush: bool, any_spec_version: int | None) -> None: async def test_branch_reset(any_spec_version: int | None) -> None: config = ic.RepositoryConfig.default() config.inline_chunk_threshold_bytes = 1 - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), config=config, spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.group(store=store, overwrite=True) group.create_group("a") - prev_snapshot_id = session.commit("group a") + prev_snapshot_id = await session.commit_async("group a") - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.open_group(store=store) group.create_group("b") - last_commit = session.commit("group b") + last_commit = await session.commit_async("group b") keys = {k async for k in store.list()} assert "a/zarr.json" in 
keys assert "b/zarr.json" in keys with pytest.raises(ic.IcechunkError, match="branch update conflict"): - repo.reset_branch( + await repo.reset_branch_async( "main", prev_snapshot_id, from_snapshot_id="1CECHNKREP0F1RSTCMT0" ) - assert last_commit == repo.lookup_branch("main") + assert last_commit == await repo.lookup_branch_async("main") - repo.reset_branch("main", prev_snapshot_id, from_snapshot_id=last_commit) - assert prev_snapshot_id == repo.lookup_branch("main") + await repo.reset_branch_async("main", prev_snapshot_id, from_snapshot_id=last_commit) + assert prev_snapshot_id == await repo.lookup_branch_async("main") - session = repo.readonly_session("main") + session = await repo.readonly_session_async("main") store = session.store keys = {k async for k in store.list()} @@ -282,52 +286,52 @@ async def test_branch_reset(any_spec_version: int | None) -> None: async def test_tag_delete(any_spec_version: int | None) -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), spec_version=any_spec_version, ) - snap = repo.lookup_branch("main") - repo.create_tag("tag", snap) - repo.delete_tag("tag") + snap = await repo.lookup_branch_async("main") + await repo.create_tag_async("tag", snap) + await repo.delete_tag_async("tag") with pytest.raises(ic.IcechunkError): - repo.delete_tag("tag") + await repo.delete_tag_async("tag") with pytest.raises(ic.IcechunkError): - repo.create_tag("tag", snap) + await repo.create_tag_async("tag", snap) async def test_session_with_as_of(any_spec_version: int | None) -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), spec_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store times = [] group = zarr.group(store=store, overwrite=True) - sid = session.commit("root") - times.append(next(repo.ancestry(snapshot_id=sid)).written_at) + sid 
= await session.commit_async("root") + times.append((await first_ancestry(repo, snapshot_id=sid)).written_at) for i in range(5): - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store group = zarr.open_group(store=store) group.create_group(f"child {i}") - sid = session.commit(f"child {i}") - times.append(next(repo.ancestry(snapshot_id=sid)).written_at) + sid = await session.commit_async(f"child {i}") + times.append((await first_ancestry(repo, snapshot_id=sid)).written_at) - ancestry = list(p for p in repo.ancestry(branch="main")) + ancestry = await async_ancestry(repo, branch="main") assert len(ancestry) == 7 # initial + root + 5 children - store = repo.readonly_session("main", as_of=times[-1]).store + store = (await repo.readonly_session_async("main", as_of=times[-1])).store group = zarr.open_group(store=store, mode="r") for i, time in enumerate(times): - store = repo.readonly_session("main", as_of=time).store + store = (await repo.readonly_session_async("main", as_of=time)).store group = zarr.open_group(store=store, mode="r") expected_children = {f"child {j}" for j in range(i)} actual_children = {g[0] for g in group.members()} @@ -335,17 +339,17 @@ async def test_session_with_as_of(any_spec_version: int | None) -> None: async def test_default_commit_metadata(any_spec_version: int | None) -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), spec_version=any_spec_version, ) - repo.set_default_commit_metadata({"user": "test"}) - session = repo.writable_session("main") + await asyncio.to_thread(lambda: repo.set_default_commit_metadata({"user": "test"})) + session = await repo.writable_session_async("main") root = zarr.group(store=session.store, overwrite=True) root.create_group("child") - sid = session.commit("root") - snap = next(repo.ancestry(snapshot_id=sid)) + sid = await session.commit_async("root") + snap = await first_ancestry(repo, 
snapshot_id=sid) assert snap.metadata == {"user": "test"} @@ -391,34 +395,34 @@ def test_set_metadata() -> None: async def test_set_metadata_async() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - assert repo.metadata == {} + assert await repo.get_metadata_async() == {} await repo.set_metadata_async({"user": "test"}) assert await repo.get_metadata_async() == {"user": "test"} async def test_update_metadata() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - assert repo.metadata == {} + assert await repo.get_metadata_async() == {} - repo.update_metadata({"user": "test"}) - assert repo.get_metadata() == {"user": "test"} - repo.update_metadata({"foo": 42}) - assert repo.get_metadata() == {"user": "test", "foo": 42} - repo.update_metadata({"foo": 43}) - assert repo.get_metadata() == {"user": "test", "foo": 43} + await repo.update_metadata_async({"user": "test"}) + assert await repo.get_metadata_async() == {"user": "test"} + await repo.update_metadata_async({"foo": 42}) + assert await repo.get_metadata_async() == {"user": "test", "foo": 42} + await repo.update_metadata_async({"foo": 43}) + assert await repo.get_metadata_async() == {"user": "test", "foo": 43} async def test_update_metadata_async() -> None: - repo = ic.Repository.create( + repo = await ic.Repository.create_async( storage=ic.in_memory_storage(), ) - assert repo.metadata == {} + assert await repo.get_metadata_async() == {} await repo.update_metadata_async({"user": "test"}) assert await repo.get_metadata_async() == {"user": "test"} @@ -450,7 +454,7 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) air_temp[:, :] = 42 assert air_temp[200, 6] == 42 - status = session.status() + status = await asyncio.to_thread(lambda: session.status()) assert status.new_groups == {"/"} assert status.new_arrays == {"/air_temp"} assert 
list(status.updated_chunks.keys()) == ["/air_temp"] @@ -542,9 +546,7 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) air_temp = cast("zarr.core.array.Array[Any]", group["air_temp"]) assert air_temp[200, 6] == 90 - parents = [ - parent async for parent in repo.async_ancestry(snapshot_id=feature_snapshot_id) - ] + parents = await async_ancestry(repo, snapshot_id=feature_snapshot_id) assert [snap.message for snap in parents] == [ "commit 3", "commit 2", @@ -552,13 +554,16 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) "Repository initialized", ] assert parents[-1].id == "1CECHNKREP0F1RSTCMT0" - assert [len(repo.list_manifest_files(snap.id)) for snap in parents] == [1, 1, 1, 0] + assert [len(await repo.list_manifest_files_async(snap.id)) for snap in parents] == [ + 1, + 1, + 1, + 0, + ] assert sorted(parents, key=lambda p: p.written_at) == list(reversed(parents)) assert len(set([snap.id for snap in parents])) == 4 - assert [parent async for parent in repo.async_ancestry(tag="v1.0")] == parents - assert [ - parent async for parent in repo.async_ancestry(branch="feature-not-dead") - ] == parents + assert await async_ancestry(repo, tag="v1.0") == parents + assert await async_ancestry(repo, branch="feature-not-dead") == parents diff = await repo.diff_async(to_tag="v1.0", from_snapshot_id=parents[-1].id) assert diff.new_groups == {"/"} @@ -602,7 +607,7 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) "air_temp", shape=(1000, 1000), chunks=(100, 100), dtype="i4", overwrite=True ) assert ( - repr(session.status()) + repr(await asyncio.to_thread(lambda: session.status())) == """\ Arrays created: /air_temp @@ -622,11 +627,11 @@ async def test_timetravel_async(using_flush: bool, any_spec_version: int | None) tag_snapshot_id = await repo.lookup_tag_async("v1.0") assert tag_snapshot_id == feature_snapshot_id - actual = next(iter([parent async for parent in 
repo.async_ancestry(tag="v1.0")])) + actual = (await async_ancestry(repo, tag="v1.0"))[0] assert actual == await repo.lookup_snapshot_async(actual.id) if any_spec_version is None or any_spec_version > 1: - ops = [type(op) for op in repo.ops_log()] + ops = await asyncio.to_thread(lambda: [type(op) for op in repo.ops_log()]) flush_or_commit = ( [ic.NewCommitUpdate] if not using_flush @@ -711,7 +716,7 @@ async def test_session_with_as_of_async(any_spec_version: int | None) -> None: times = [] group = zarr.group(store=session.store, overwrite=True) sid = await session.commit_async("root") - to_append = await anext(repo.async_ancestry(snapshot_id=sid)) + to_append = await first_ancestry(repo, snapshot_id=sid) times.append(to_append.written_at) for i in range(5): @@ -719,10 +724,10 @@ async def test_session_with_as_of_async(any_spec_version: int | None) -> None: group = zarr.open_group(store=session.store) group.create_group(f"child {i}") sid = await session.commit_async(f"child {i}") - to_append = await anext(repo.async_ancestry(snapshot_id=sid)) + to_append = await first_ancestry(repo, snapshot_id=sid) times.append(to_append.written_at) - ancestry = [p async for p in repo.async_ancestry(branch="main")] + ancestry = await async_ancestry(repo, branch="main") assert len(ancestry) == 7 # initial + root + 5 children store = (await repo.readonly_session_async("main", as_of=times[-1])).store @@ -926,7 +931,7 @@ async def test_repository_lifecycle_async(any_spec_version: int | None) -> None: new_config_sync = ic.RepositoryConfig.default() new_config_sync.inline_chunk_threshold_bytes = 2048 - reopened_repo_sync = repo.reopen(config=new_config_sync) + reopened_repo_sync = await repo.reopen_async(config=new_config_sync) assert reopened_repo_sync.config.inline_chunk_threshold_bytes == 2048 # Test open_or_create_async with new storage (should create) @@ -965,7 +970,7 @@ async def test_rewrite_manifests_async(any_spec_version: int | None) -> None: second_commit = await 
session.commit_async("more data") # Get initial ancestry to verify commits exist - initial_ancestry = [snap async for snap in repo.async_ancestry(branch="main")] + initial_ancestry = await async_ancestry(repo, branch="main") assert len(initial_ancestry) >= 3 # initial + first_commit + second_commit # Test rewrite_manifests_async @@ -978,7 +983,7 @@ async def test_rewrite_manifests_async(any_spec_version: int | None) -> None: assert rewrite_commit != second_commit # Verify ancestry after rewrite - new_ancestry = [snap async for snap in repo.async_ancestry(branch="main")] + new_ancestry = await async_ancestry(repo, branch="main") extra_commits = 1 if any_spec_version == 1 else 0 # we do amend in IC2+ assert len(new_ancestry) == len(initial_ancestry) + extra_commits assert new_ancestry[0].message == "rewritten manifests" @@ -1089,7 +1094,7 @@ async def test_long_ops_log(spec_version: int | None) -> None: snap = await repo.lookup_branch_async("main") for i in range(1, NUM_BRANCHES + 1): await repo.create_branch_async(str(i), snap) - updates = [update async for update in repo.ops_log_async()] + updates = await asyncio.to_thread(lambda: list(repo.ops_log())) assert len(updates) == NUM_BRANCHES + 1 t = type(updates[0]) diff --git a/icechunk-python/tests/test_virtual_ref.py b/icechunk-python/tests/test_virtual_ref.py index 2523324d8..9aa2f4a55 100644 --- a/icechunk-python/tests/test_virtual_ref.py +++ b/icechunk-python/tests/test_virtual_ref.py @@ -60,13 +60,13 @@ async def test_write_minio_virtual_refs( ) # Open the store - repo = Repository.open_or_create( + repo = await Repository.open_or_create_async( storage=in_memory_storage(), config=config, authorize_virtual_chunk_access=credentials, create_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store array = zarr.create_array( @@ -210,7 +210,7 @@ async def test_write_minio_virtual_refs( # TODO: we should include the key and other 
info in the exception await store.get("c/0/0/2", prototype=buffer_prototype) - all_locations = set(session.all_virtual_chunk_locations()) + all_locations = set(await session.all_virtual_chunk_locations_async()) assert f"s3://testbucket/{prefix}/non-existing" in all_locations assert f"s3://testbucket/{prefix}/chunk-1" in all_locations assert f"s3://testbucket/{prefix}/chunk-2" in all_locations @@ -226,7 +226,7 @@ async def test_write_minio_virtual_refs( with pytest.raises(IcechunkError, match="chunk has changed"): await store.get("c/3/0/1", prototype=buffer_prototype) - _snapshot_id = session.commit("Add virtual refs") + _snapshot_id = await session.commit_async("Add virtual refs") @pytest.mark.parametrize( @@ -257,13 +257,13 @@ async def test_public_virtual_refs( container = VirtualChunkContainer(url_prefix, store_config) config.set_virtual_chunk_container(container) - repo = Repository.open_or_create( + repo = await Repository.open_or_create_async( storage=local_filesystem_storage(f"{tmpdir}/virtual-{container_type}"), config=config, authorize_virtual_chunk_access={url_prefix: None}, create_version=any_spec_version, ) - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store root = zarr.Group.from_store(store=store, zarr_format=3) diff --git a/icechunk-python/tests/test_zarr/test_store/test_core.py b/icechunk-python/tests/test_zarr/test_store/test_core.py index b6b8ace07..29bac7c90 100644 --- a/icechunk-python/tests/test_zarr/test_store/test_core.py +++ b/icechunk-python/tests/test_zarr/test_store/test_core.py @@ -1,16 +1,16 @@ from icechunk import IcechunkStore -from tests.conftest import parse_repo +from tests.conftest import parse_repo_async from zarr.storage._common import make_store_path async def test_make_store_path(any_spec_version: int | None) -> None: # Memory store - repo = parse_repo( + repo = await parse_repo_async( "memory", path="", spec_version=any_spec_version, ) - session = 
repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store store_path = await make_store_path(store) assert isinstance(store_path.store, IcechunkStore) diff --git a/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py b/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py index eb27153e1..c5b2e217e 100644 --- a/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py +++ b/icechunk-python/tests/test_zarr/test_store/test_icechunk_store.py @@ -55,21 +55,21 @@ def store_kwargs(self, tmpdir: Path) -> dict[str, Any]: @pytest.fixture async def store(self, store_kwargs: dict[str, Any]) -> IcechunkStore: read_only = store_kwargs.pop("read_only") - repo = Repository.open_or_create(**store_kwargs) + repo = await Repository.open_or_create_async(**store_kwargs) if read_only: - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") else: - session = repo.writable_session("main") + session = await repo.writable_session_async("main") return session.store @pytest.fixture async def store_not_open(self, store_kwargs: dict[str, Any]) -> IcechunkStore: read_only = store_kwargs.pop("read_only") - repo = Repository.open_or_create(**store_kwargs) + repo = await Repository.open_or_create_async(**store_kwargs) if read_only: - session = repo.readonly_session(branch="main") + session = await repo.readonly_session_async(branch="main") else: - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store store._is_open = False @@ -98,11 +98,11 @@ def test_store_repr(self, store: IcechunkStore) -> None: async def test_store_open_read_only( self, store: IcechunkStore, store_kwargs: dict[str, Any], read_only: bool ) -> None: - repo = Repository.open(**store_kwargs) + repo = await Repository.open_async(**store_kwargs) if read_only: - session = repo.readonly_session(branch="main") + session = await 
repo.readonly_session_async(branch="main") else: - session = repo.writable_session("main") + session = await repo.writable_session_async("main") store = session.store assert store._is_open assert store.read_only == read_only @@ -111,8 +111,8 @@ async def test_read_only_store_raises( # type: ignore[override] self, store: IcechunkStore, store_kwargs: dict[str, Any] ) -> None: kwargs = {**store_kwargs} - repo = Repository.open(**kwargs) - session = repo.readonly_session(branch="main") + repo = await Repository.open_async(**kwargs) + session = await repo.readonly_session_async(branch="main") store = session.store assert store.read_only