7 changes: 2 additions & 5 deletions .github/workflows/run_notebooks.yaml
@@ -45,12 +45,9 @@ jobs:
         uses: actions/cache/restore@v4
         with:
           path: docs/tutorials/ehrapy_data
-          key: ehrapy-data-notebook-${{ matrix.notebook }}-v1
+          key: ehrapy-data-notebook-${{ matrix.notebook }}-v2

       - name: Run ${{ matrix.notebook }} Notebook
-        env:
-          EHRDATA_DOWNLOAD_MAX_RETRIES: "5"
-          EHRDATA_DOWNLOAD_RETRY_DELAY: "10"
         run: jupyter nbconvert --to notebook --execute ${{ matrix.notebook }}

       # Only persist after a green run that had to download (cache miss), so a partial/failed download is never saved under the (immutable) key.
@@ -59,4 +56,4 @@ jobs:
         uses: actions/cache/save@v4
         with:
           path: docs/tutorials/ehrapy_data
-          key: ehrapy-data-notebook-${{ matrix.notebook }}-v1
+          key: ehrapy-data-notebook-${{ matrix.notebook }}-v2
10 changes: 5 additions & 5 deletions .github/workflows/test.yaml
@@ -84,22 +84,22 @@ jobs:
         uses: actions/cache/restore@v4
         with:
           path: ehrapy_data
-          key: ehrapy-data-tests-v1
+          key: ehrapy-data-tests-v2
       - name: run tests using hatch
         env:
           MPLBACKEND: agg
           PLATFORM: ${{ matrix.os }}
           DISPLAY: :42
-          EHRDATA_DOWNLOAD_MAX_RETRIES: "5"
-          EHRDATA_DOWNLOAD_RETRY_DELAY: "10"
-        run: uvx hatch run ${{ matrix.env.name }}:run-cov -v --color=yes -n auto
+        # --dist loadgroup so tests sharing a dataset download path (marked with @pytest.mark.xdist_group) run on the
+        # same worker and don't download/extract the same dataset concurrently on a cold cache.
+        run: uvx hatch run ${{ matrix.env.name }}:run-cov -v --color=yes -n auto --dist loadgroup
       # Only persist the cache after a green run that had to download (cache miss), so a partial/failed download is never saved under the (immutable) key.
       - name: Save datasets cache
         if: steps.data-cache.outputs.cache-hit != 'true' && success()
         uses: actions/cache/save@v4
         with:
           path: ehrapy_data
-          key: ehrapy-data-tests-v1
+          key: ehrapy-data-tests-v2
       - name: generate coverage report
         run: |
           # See https://coverage.readthedocs.io/en/latest/config.html#run-patch
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -11,7 +11,12 @@ and this project adheres to [Semantic Versioning][].
 ## [Unreleased]

 ### Fixed
-- Dataset downloads in `ehrdata.dt` no longer fail silently: the final retry now raises instead of returning a path that was never downloaded, so unreachable hosts (e.g. physionet.org) surface a clear error instead of a downstream `FileNotFoundError`. Retries are configurable via the `EHRDATA_DOWNLOAD_MAX_RETRIES` / `EHRDATA_DOWNLOAD_RETRY_DELAY` environment variables, and CI now caches downloaded datasets so flaky upstream hosts don't break the test and notebook workflows. ([#250](https://github.com/theislab/ehrdata/pull/250)) @eroell
+- {func}`~ehrdata.infer_feature_types` can handle `EHRData` objects with `X` as `None`. ([#246](https://github.com/theislab/ehrdata/pull/246)) @sueoglu
+- {func}`~ehrdata.dt.physionet2019` no longer raises a shape mismatch on the full dataset: persons whose dynamic measurements all fall outside the observation window are now padded with missing values instead of being dropped from the time series tensor. ([#251](https://github.com/theislab/ehrdata/issues/251)) @eroell
+
+### Maintenance
+- CI now caches downloaded datasets used by `ehrdata.dt`, so that flaky upstream hosts (e.g. physionet.org) are less likely to break the test and notebook workflows. ([#250](https://github.com/theislab/ehrdata/pull/250)) @eroell
+- Dataset downloads now use [pooch](https://www.fatiando.org/pooch/) instead of a custom `requests`-based downloader, aligning with the scverse ecosystem and providing caching out of the box. ([#251](https://github.com/theislab/ehrdata/issues/251)) @eroell

 ## [0.2.1]

8 changes: 6 additions & 2 deletions pyproject.toml
@@ -26,9 +26,9 @@ dependencies = [
   "anndata>=0.12.6",
   "duckdb",
   "fast-array-utils[sparse]",
-  "filelock",
-  "requests",
+  "pooch",
   "rich",
+  "tqdm",

@Zethson Zethson Jun 4, 2026

Member

We don't need tqdm as dependency because rich itself can already render the progressbar.

https://github.com/scverse/pertpy/blob/main/pertpy/data/_dataloader.py#L12

Collaborator Author

Opted for a simpler progressbar with tqdm, and not rich

Member

Like this we need a new dependency tho and it's not necessary
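
On the tqdm-versus-rich question: pooch's `HTTPDownloader` documents that `progressbar` may be an arbitrary object instead of `True`, provided it exposes a `total` attribute and `update`, `reset`, and `close` methods. Below is a minimal sketch of an adapter that would forward those calls to a `rich.progress.Progress` instance, which would make the `tqdm` dependency unnecessary. The class name `RichProgressAdapter` is hypothetical, and the call order it relies on (pooch sets `total`, calls `update(chunk_size)` per chunk, then `reset()`, `update(total)`, `close()`) is an assumption based on pooch's custom progress bar documentation; it should be checked against the installed pooch version.

```python
from typing import Any, Optional


class RichProgressAdapter:
    """Hypothetical adapter exposing the tqdm-like interface pooch expects."""

    def __init__(self, progress: Any, description: str = "Downloading...") -> None:
        # `progress` is expected to be a rich.progress.Progress instance (duck-typed here).
        self._progress = progress
        self._description = description
        self._task_id: Any = None
        self._total: Optional[int] = None

    @property
    def total(self) -> Optional[int]:
        return self._total

    @total.setter
    def total(self, value: Optional[int]) -> None:
        # Assumption: pooch assigns the content length here before the first chunk arrives.
        self._total = value
        if self._task_id is None:
            self._progress.start()
            self._task_id = self._progress.add_task(self._description, total=value)
        else:
            self._progress.update(self._task_id, total=value)

    def update(self, n: int) -> None:
        # Advance the bar by n bytes.
        self._progress.update(self._task_id, advance=n)

    def reset(self) -> None:
        # Set the completed count back to zero while keeping the task.
        self._progress.reset(self._task_id)

    def close(self) -> None:
        # Stop rendering the live display.
        self._progress.stop()
```

Assuming both `rich` and `pooch` are installed, the adapter would be passed as `pooch.HTTPDownloader(progressbar=RichProgressAdapter(Progress()))`, with `Progress` imported from `rich.progress`.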

   "xarray",
   "zarr>=3",
 ]
@@ -178,6 +178,10 @@ ini_options.xfail_strict = true
 ini_options.addopts = [
   "--import-mode=importlib", # allow using test files with same name
 ]
+# Provided by pytest-xdist (CI only); registered here so runs without xdist don't warn about an unknown marker.
+ini_options.markers = [
+  "xdist_group: run tests sharing a dataset download path on the same xdist worker (needs --dist loadgroup)",
+]

 [tool.coverage]
 run.omit = [
128 changes: 41 additions & 87 deletions src/ehrdata/dt/_dataloader.py
@@ -1,30 +1,27 @@
 from __future__ import annotations

+import logging
 import os
 import shutil
 import tempfile
-import time
+import warnings
 from pathlib import Path, PurePath
 from typing import Literal, get_args
 from urllib.parse import urlparse

-import requests
-from filelock import FileLock
-from requests.exceptions import RequestException
-from rich.progress import Progress
-
 from ehrdata._logger import logger

+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", message="IProgress not found")

Member

Should we really do this?

Collaborator Author

an ugly warning is raised during import without this, since right now I use tqdm, instead of rich+ipywidgets
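
For context on "Should we really do this?": `warnings.catch_warnings()` restores the previous filter state on exit, so the `ignore` filter only applies to code executed inside the `with` block, and only to warnings whose message starts with the given pattern. The sketch below demonstrates that scoping with the standard library alone. The helper names `noisy_import` and `collect_messages` are hypothetical, and the warning text merely mimics the tqdm message referred to in the thread.

```python
import warnings


def noisy_import() -> None:
    """Stand-in for an import that emits warnings (hypothetical helper)."""
    warnings.warn("IProgress not found. Please update jupyter and ipywidgets.", UserWarning, stacklevel=2)
    warnings.warn("some unrelated warning", UserWarning, stacklevel=2)


def collect_messages(suppress: bool) -> list[str]:
    """Run noisy_import() and return the text of every warning that got through."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record everything that is not explicitly ignored
        if suppress:
            # Same pattern as in the diff: the filter lives only inside this block.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="IProgress not found")
                noisy_import()
        else:
            noisy_import()
    return [str(w.message) for w in caught]
```

Only the matching warning is dropped; unrelated warnings raised inside the block still surface, and nothing is suppressed once the block has exited.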

+    import pooch
+
+pooch.get_logger().setLevel(logging.WARNING)
+
 COMPRESSION_FORMATS = Literal["tar.gz", "gztar", "zip", "tar", "gz", "bz", "xz"]
 COMPRESSION_FORMATS_LIST = list(get_args(COMPRESSION_FORMATS))
 RAW_FORMATS = Literal["csv", "txt", "parquet", "h5ad", "zarr"]
 RAW_FORMATS_LIST = list(get_args(RAW_FORMATS))

-# Retry defaults. CI can override these via environment variables to fail fast instead of waiting through
-# the full exponential backoff; regular users never need to set them.
-DEFAULT_MAX_RETRIES = max(1, int(os.environ.get("EHRDATA_DOWNLOAD_MAX_RETRIES", "5")))
-DEFAULT_RETRY_DELAY = int(os.environ.get("EHRDATA_DOWNLOAD_RETRY_DELAY", "10"))
-

 def _download(
     url: str,
@@ -36,11 +33,12 @@ def _download(
     *,
     overwrite: bool = False,
     timeout: int = 60,
-    max_retries: int = DEFAULT_MAX_RETRIES,
-    retry_delay: int = DEFAULT_RETRY_DELAY,
 ) -> None | Path: # pragma: no cover
     """Downloads a file irrespective of format.

+    The download itself, including caching, is delegated to
+    `pooch <https://www.fatiando.org/pooch/>`_, in line with the scverse ecosystem.
+
     Args:
         url: URL to download.
         output_filename: Name of the file to download. If not specified, the file name will be inferred from the URL.
@@ -50,8 +48,6 @@ def _download(
         block_size: Block size for downloads in bytes.
         overwrite: Whether to overwrite existing files.
         timeout: Request timeout in seconds.
-        max_retries: Maximum number of download attempts before giving up. Defaults to 5.
-        retry_delay: Base delay in seconds between attempts (grows exponentially). Defaults to 10.
     """

     def _sanitize_filename(filename: str) -> str:
@@ -69,6 +65,7 @@ def _remove_archive_extension(filename: str) -> str:
         output_path = tempfile.gettempdir()

     output_path = Path(output_path)
+    output_path.mkdir(parents=True, exist_ok=True)

     url_filename = PurePath(urlparse(url).path).name
     suffix = url_filename.split(".")[-1]
@@ -83,83 +80,40 @@ def _remove_archive_extension(filename: str) -> str:
         file_ending = suffix

     if file_ending in RAW_FORMATS_LIST:
+        download_dir = output_path
         raw_data_output_path = output_path / output_filename
         path_to_check = raw_data_output_path
     elif file_ending in COMPRESSION_FORMATS_LIST:
-        tmpdir = tempfile.mkdtemp()
-        raw_data_output_path = Path(tmpdir) / output_filename
+        download_dir = Path(tempfile.mkdtemp())
+        raw_data_output_path = download_dir / output_filename
         path_to_check = output_path / _remove_archive_extension(output_filename)
     else:
         msg = f"Unknown file format: {file_ending}"
         raise RuntimeError(msg)

-    lock_path = f"{path_to_check}.lock"
-    with FileLock(lock_path, timeout=600):
-        if path_to_check.exists():
-            warning = f"File {path_to_check} already exists!"
-            if not overwrite:
-                return path_to_check
-            else:
-                logger.warning(f"{warning} Overwriting...")
-
-        temp_filename = f"{raw_data_output_path}.part"
-
-        retry_count = 0
-        while retry_count < max_retries:
-            try:
-                headers = {"User-Agent": "ehrdata/1.0.0 (https://github.com/theislab/ehrdata)"}
-                head_response = requests.head(url, timeout=timeout, headers=headers)
-                head_response.raise_for_status()
-                content_length = int(head_response.headers.get("content-length", 0))
-                free_space = shutil.disk_usage(output_path).free
-
-                if content_length > free_space:
-                    msg = f"Insufficient disk space. Need {content_length} bytes, but only {free_space} available."
-                    raise OSError(msg)
-
-                response = requests.get(url, stream=True, headers=headers, timeout=timeout)
-                response.raise_for_status()
-                total = int(response.headers.get("content-length", 0))
-
-                with Progress(refresh_per_second=5) as progress:
-                    task = progress.add_task("[red]Downloading...", total=total)
-                    with Path(temp_filename).open("wb") as file:
-                        for data in response.iter_content(block_size):
-                            file.write(data)
-                            progress.update(task, advance=len(data))
-                    progress.update(task, completed=total, refresh=True)
-
-                Path(temp_filename).replace(raw_data_output_path)
-
-                if file_ending in COMPRESSION_FORMATS_LIST:
-                    shutil.unpack_archive(raw_data_output_path, output_path)
-
-                return path_to_check
-
-            except (OSError, RequestException) as e:
-                retry_count += 1
-                if retry_count < max_retries:
-                    # Exponential backoff: base delay * 2^(attempt-1)
-                    backoff_delay = retry_delay * (2 ** (retry_count - 1))
-                    logger.warning(
-                        f"Download attempt {retry_count}/{max_retries} failed: {e!s}. Retrying in {backoff_delay} seconds..."
-                    )
-                    time.sleep(backoff_delay)
-                else:
-                    # Final attempt failed: surface the error instead of silently returning a missing path.
-                    logger.error(f"Download failed after {max_retries} attempts: {e!s}")
-                    if Path(temp_filename).exists():
-                        Path(temp_filename).unlink(missing_ok=True)
-                    raise
-
-            except Exception as e:
-                logger.error(f"Download failed: {e!s}")
-                if Path(temp_filename).exists():
-                    Path(temp_filename).unlink(missing_ok=True)
-                raise
-            finally:
-                if Path(temp_filename).exists():
-                    Path(temp_filename).unlink(missing_ok=True)
-                Path(lock_path).unlink(missing_ok=True)
-
-    return path_to_check
+    if path_to_check.exists():
+        warning = f"File {path_to_check} already exists!"
+        if not overwrite:
+            return path_to_check
+        logger.warning(f"{warning} Overwriting...")
+        # pooch does not re-fetch an existing file when no hash is given, so remove it to force a download.
+        if raw_data_output_path.exists():
+            raw_data_output_path.unlink()
+
+    pooch.retrieve(
+        url=url,
+        known_hash=None,
+        fname=output_filename,
+        path=str(download_dir),
+        downloader=pooch.HTTPDownloader(
+            progressbar=True,
+            chunk_size=block_size,
+            timeout=timeout,
+            headers={"User-Agent": "ehrdata/1.0.0 (https://github.com/theislab/ehrdata)"},
+        ),
+    )
+
+    if file_ending in COMPRESSION_FORMATS_LIST:
+        shutil.unpack_archive(raw_data_output_path, output_path)
+
+    return path_to_check
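
The explicit retry loop is removed from `_download` in this file, so callers that still want retries around a download would have to add them from the outside. The following is a minimal sketch of the removed behaviour (a wait of `retry_delay * 2 ** (attempt - 1)` between attempts, re-raising after the final one) written as a generic wrapper. The function name `retry_with_backoff` and its parameters are hypothetical and not part of ehrdata or pooch; the `sleep` argument is injectable only so that the example can run without actually waiting.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    func: Callable[[], T],
    *,
    max_retries: int = 5,
    retry_delay: float = 10,
    exceptions: Tuple[Type[BaseException], ...] = (OSError,),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call func() up to max_retries times, doubling the wait after each failure."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    for attempt in range(1, max_retries + 1):
        try:
            return func()
        except exceptions:
            if attempt == max_retries:
                # Final attempt failed: surface the error instead of returning silently.
                raise
            # Exponential backoff: base delay * 2^(attempt-1)
            sleep(retry_delay * 2 ** (attempt - 1))
    raise AssertionError("unreachable")
```

`requests.exceptions.RequestException` subclasses `OSError`, so the default `exceptions` tuple would also cover HTTP failures raised by a `requests`-based downloader.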
39 changes: 6 additions & 33 deletions src/ehrdata/dt/datasets.py
@@ -7,7 +7,6 @@
 import numpy as np
 import pandas as pd

-from ehrdata._logger import logger
 from ehrdata.core.constants import DEFAULT_DATA_PATH, DEFAULT_TEM_LAYER_NAME
 from ehrdata.dt._dataloader import _download
 from ehrdata.io import read_csv, read_h5ad
@@ -301,8 +300,6 @@ def _setup_eunomia_datasets(
     )

     if nested_omop_tables_folder:
-        if len(list((data_path / nested_omop_tables_folder).glob("*.csv"))) > 0:
-            logger.info(f"Moving files from {data_path / nested_omop_tables_folder} to {data_path}")
         for file_path in (data_path / nested_omop_tables_folder).glob("*.csv"):
             shutil.move(file_path, data_path)

@@ -856,6 +853,12 @@ def _create_edata_from_physionet_long_format(

     xa = df_long.set_index(["RecordID", "Parameter", "interval_step"]).to_xarray()

+    # persons whose dynamic measurements all fall outside the observation window are dropped from the long->xarray pivot;
+    # reindex to every person in obs (in obs order) so the layer stays aligned with obs and missing persons are padded with missing values instead of producing a shape mismatch
+    xa = xa.reindex(
+        RecordID=obs.index.values,
+        fill_value=np.nan,
+    )
     # since NaNs are dropped, it can happen that a Parameter is completely dropped when it has no values for the subset of persons considered
     # to provide a full set of Parameters every time, we reindex to add the missing Parameters back in, just with missing values
     xa = xa.reindex(
@@ -973,26 +976,11 @@ def diabetes_130_raw(
         >>> import ehrdata as ed
         >>> edata = ed.dt.diabetes_130_raw()
     """
-    import os
-
-    # Use more aggressive retry settings in CI environments
-    is_ci = os.getenv("CI", "false").lower() == "true"
-    download_kwargs = {}
-    if is_ci:
-        download_kwargs.update(
-            {
-                "timeout": 120,
-                "max_retries": 8,
-                "retry_delay": 15,
-            }
-        )
-
     _download(
         url="https://exampledata.scverse.org/ehrapy/diabetes_130_raw.csv",
         output_path=DEFAULT_DATA_PATH,
         output_filename="diabetes_130_raw.csv",
         raw_format="csv",
-        **download_kwargs,
     )
     adata = read_csv(
         filename=f"{DEFAULT_DATA_PATH}/diabetes_130_raw.csv",
@@ -1018,26 +1006,11 @@ def diabetes_130_fairlearn(
         >>> import ehrdata as ed
         >>> edata = ed.dt.diabetes_130_fairlearn()
     """
-    import os
-
-    # Use more aggressive retry settings in CI environments
-    is_ci = os.getenv("CI", "false").lower() == "true"
-    download_kwargs = {}
-    if is_ci:
-        download_kwargs.update(
-            {
-                "timeout": 120,
-                "max_retries": 8,
-                "retry_delay": 15,
-            }
-        )
-
     _download(
         url="https://exampledata.scverse.org/ehrapy/diabetes_130_fairlearn.csv",
         output_path=DEFAULT_DATA_PATH,
         output_filename="diabetes_130_fairlearn.csv",
         raw_format="csv",
-        **download_kwargs,
     )
     edata = read_csv(
         filename=f"{DEFAULT_DATA_PATH}/diabetes_130_fairlearn.csv",
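
The `physionet2019` change in this file reindexes the pivoted array to every `RecordID` in `obs`, so persons without any in-window measurement become all-NaN rows instead of disappearing from the tensor. A dependency-free sketch of the same idea follows, using nested lists rather than xarray. The function `pivot_with_padding` and its arguments are hypothetical and only illustrate why the first axis stays aligned with `obs`.

```python
import math


def pivot_with_padding(records, person_ids, parameters, n_steps):
    """Pivot (person, parameter, step, value) records into a dense
    persons x parameters x steps nested list, NaN where nothing was measured."""
    person_pos = {pid: i for i, pid in enumerate(person_ids)}
    param_pos = {name: j for j, name in enumerate(parameters)}
    # Allocate a row for every person up front: this plays the role of the reindex step.
    tensor = [[[math.nan] * n_steps for _ in parameters] for _ in person_ids]
    for pid, name, step, value in records:
        if 0 <= step < n_steps:  # measurements outside the observation window are skipped
            tensor[person_pos[pid]][param_pos[name]][step] = value
    return tensor
```

A person whose only measurements lie outside the window still occupies a row (filled with NaN), so the number of rows always equals `len(person_ids)`.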
5 changes: 5 additions & 0 deletions tests/dt/test_dt.py
@@ -51,6 +51,7 @@ def duckdb_connection():
     con.close()


+@pytest.mark.xdist_group(name="dataset_mimic_iv_omop")
 def test_mimic_iv_omop():
     duckdb_connection = duckdb.connect()
     ed.dt.mimic_iv_omop(backend_handle=duckdb_connection)
@@ -78,6 +79,7 @@ def test_synthea27nj_omop():
     duckdb_connection.close()


+@pytest.mark.xdist_group(name="dataset_physionet2012")
 def test_physionet2012():
     edata = ed.dt.physionet2012(layer=DEFAULT_TEM_LAYER_NAME)
     assert edata.shape == (11988, 37, 48)
@@ -113,6 +115,7 @@ def test_physionet2012():
     assert np.isclose(edata[edata.obs.index.get_loc("152871"), "HR", 28].layers[DEFAULT_TEM_LAYER_NAME].item(), 68)


+@pytest.mark.xdist_group(name="dataset_physionet2012")
 def test_physionet2012_arguments():
     edata = ed.dt.physionet2012(
         layer=DEFAULT_TEM_LAYER_NAME,
@@ -129,6 +132,7 @@ def test_physionet2012_arguments():
     assert edata.var.shape == (37, 1)


+@pytest.mark.xdist_group(name="dataset_physionet2019")
 def test_physionet2019():
     edata = ed.dt.physionet2019(layer=DEFAULT_TEM_LAYER_NAME, n_samples=10)
     assert edata.shape == (10, 35, 48)
@@ -152,6 +156,7 @@ def test_physionet2019():
     )


+@pytest.mark.xdist_group(name="dataset_physionet2019")
 def test_physionet2019_arguments():
     edata = ed.dt.physionet2019(
         layer=DEFAULT_TEM_LAYER_NAME,