42 changes: 32 additions & 10 deletions cellarium/ml/data/distributed_anndata.py
@@ -18,7 +18,7 @@
from boltons.cacheutils import LRU
from braceexpand import braceexpand

from cellarium.ml.data.fileio import read_h5ad_file
from cellarium.ml.data.fileio import backed_mode_default, backed_mode_type, read_h5ad_file
from cellarium.ml.data.schema import AnnDataSchema


@@ -28,6 +28,8 @@ class getattr_mode:

_GETATTR_MODE = getattr_mode()

allowed_backed_modes = [None, True, False, "r"]


@contextmanager
def lazy_getattr():
@@ -155,6 +157,12 @@ class DistributedAnnDataCollection(AnnCollection):
obs_columns_to_validate:
Subset of columns to validate in the :attr:`obs` attribute.
If ``None``, all columns are validated.
backed:
Optional backing mode for the h5ad files. ``'r'`` leaves count matrices
on disk until specific cell indices are queried, enabling the use of very large
h5ad files, while ``None`` loads each file's entire count matrix into the
cache as needed, a strategy that requires the data to be split into smaller
h5ad shards. See :func:`anndata.read_h5ad` for details on backing modes.
"""

def __init__(
@@ -171,6 +179,7 @@ def __init__(
convert: ConvertType | None = None,
indices_strict: bool = True,
obs_columns_to_validate: Sequence[str] | None = None,
backed: backed_mode_type = backed_mode_default,
):
self.filenames = list(braceexpand(filenames) if isinstance(filenames, str) else filenames)
if (shard_size is None) and (last_shard_size is not None):
@@ -192,8 +201,11 @@ def __init__(
self.cache = LRU(max_cache_size)
self.max_cache_size = max_cache_size
self.cache_size_strictly_enforced = cache_size_strictly_enforced
if backed not in allowed_backed_modes:
raise ValueError(f"Invalid backed mode: {backed}. Choose from {allowed_backed_modes}")
self.backed = backed
# schema
adata0 = self.cache[self.filenames[0]] = read_h5ad_file(self.filenames[0])
adata0 = self.cache[self.filenames[0]] = read_h5ad_file(self.filenames[0], backed=backed)
if len(adata0) != limits[0]:
raise ValueError(
f"The number of cells in the first anndata file ({len(adata0)}) "
@@ -203,7 +215,7 @@ def __init__(
self.schema = AnnDataSchema(adata0, obs_columns_to_validate)
# lazy anndatas
lazy_adatas = [
LazyAnnData(filename, (start, end), self.schema, self.cache)
LazyAnnData(filename, (start, end), self.schema, self.cache, backed=backed)
for start, end, filename in zip([0] + limits, limits, self.filenames)
]
# use filenames as default keys
@@ -298,10 +310,10 @@ def __getstate__(self):
def __setstate__(self, state):
self.__dict__.update(state)
self.cache = LRU(self.max_cache_size)
adata0 = self.cache[self.filenames[0]] = read_h5ad_file(self.filenames[0])
adata0 = self.cache[self.filenames[0]] = read_h5ad_file(self.filenames[0], backed=self.backed)
self.schema = AnnDataSchema(adata0, self.obs_columns_to_validate)
self.adatas = [
LazyAnnData(filename, (start, end), self.schema, self.cache)
LazyAnnData(filename, (start, end), self.schema, self.cache, backed=self.backed)
for start, end, filename in zip([0] + self.limits, self.limits, self.filenames)
]
self.obs_names = pd.Index([f"cell_{i}" for i in range(self.limits[-1])])
@@ -323,6 +335,11 @@ class LazyAnnData:
Schema used as a reference for lazy attributes.
cache:
Shared LRU cache storing buffered anndatas.
backed:
Optional backing mode for the anndata. ``'r'`` leaves the count matrix
on disk, while ``None`` loads the count matrix into memory (when the anndata
is cached by accessing the ``.adata`` property).
See :func:`anndata.read_h5ad` for details on backing modes.
"""

_lazy_attrs = ["obs", "obsm", "layers", "var", "varm", "varp", "var_names"]
@@ -343,10 +360,14 @@ def __init__(
limits: tuple[int, int],
schema: AnnDataSchema,
cache: LRU | None = None,
backed: backed_mode_type = backed_mode_default,
):
self.filename = filename
self.limits = limits
self.schema = schema
if backed not in allowed_backed_modes:
raise ValueError(f"Invalid backed mode: {backed}. Choose from {allowed_backed_modes}")
self.backed = backed
if cache is None:
cache = LRU()
self.cache = cache
@@ -382,16 +403,16 @@ def cached(self) -> bool:

@property
def adata(self) -> AnnData:
"""Return backed anndata from the filename"""
"""Return anndata from the filename"""
try:
adata = self.cache[self.filename]
except KeyError:
# fetch anndata
adata = read_h5ad_file(self.filename)
adata = read_h5ad_file(self.filename, backed=self.backed)
# validate anndata
if self.n_obs != adata.n_obs:
raise ValueError(
"Expected `n_obs` for LazyAnnData object and backed anndata to match "
"Expected `n_obs` for LazyAnnData object and loaded anndata to match "
f"but found {self.n_obs} and {adata.n_obs}, respectively."
)
self.schema.validate_anndata(adata)
@@ -426,8 +447,9 @@ def __repr__(self) -> str:
buffered = "Cached "
else:
buffered = ""
backed_at = f" backed at {str(self.filename)!r}"
descr = f"{buffered}LazyAnnData object with n_obs × n_vars = {self.n_obs} × {self.n_vars}{backed_at}"
located_at = f" referencing {str(self.filename)!r}"
backed = " in backed mode" if (self.backed in [True, "r"]) else " in memory mode"
descr = f"{buffered}LazyAnnData object with n_obs × n_vars = {self.n_obs} × {self.n_vars}{located_at}{backed}"
if self.cached:
for attr in self._all_attrs:
keys = getattr(self, attr).keys()
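For orientation beyond the diff, a minimal usage sketch of the new ``backed`` parameter on ``DistributedAnnDataCollection``; the bucket path, shard pattern, and sizes are hypothetical, not taken from this PR.

# Hypothetical usage sketch (not part of this diff): open a sharded collection
# in backed mode so count matrices stay on disk until specific cells are indexed.
from cellarium.ml.data.distributed_anndata import DistributedAnnDataCollection

dadc = DistributedAnnDataCollection(
    filenames="gs://my-bucket/shards/shard_{000..009}.h5ad",  # hypothetical path
    shard_size=10_000,
    max_cache_size=2,
    backed="r",  # 'r' or True: leave X on disk; None or False: load each shard into the cache
)

With ``backed=None`` (in-memory mode), each shard is read fully when it enters the LRU cache, which is why that mode calls for smaller shard files.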
64 changes: 53 additions & 11 deletions cellarium/ml/data/fileio.py
@@ -1,18 +1,26 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import os
import re
import shutil
import tempfile
import urllib.request
from typing import Literal

from anndata import AnnData, read_h5ad
from google.cloud.storage import Client

url_schemes = ("http:", "https:", "ftp:")
backed_mode_type = Literal["r"] | bool | None
backed_mode_default: backed_mode_type = "r"


def read_h5ad_gcs(filename: str, storage_client: Client | None = None) -> AnnData:
def read_h5ad_gcs(
filename: str,
storage_client: Client | None = None,
backed: backed_mode_type = backed_mode_default,
) -> AnnData:
r"""
Read an ``.h5ad``-formatted hdf5 file from Google Cloud Storage.

@@ -22,6 +30,9 @@ def read_h5ad_gcs(filename: str, storage_client: Client | None = None) -> AnnData:

Args:
filename: Path to the data file in Cloud Storage.
backed: See :func:`anndata.read_h5ad` for details on backed mode.
``'r'`` or ``True`` loads in backed mode instead of fully loading into memory;
``False`` or ``None`` uses in-memory mode.
"""
if not filename.startswith("gs:"):
raise ValueError("The filename must start with 'gs:' protocol name.")
@@ -35,11 +46,20 @@ def read_h5ad_gcs(filename: str, storage_client: Client | None = None) -> AnnData:
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_name)

with blob.open("rb") as f:
return read_h5ad(f)
# write to a named temporary file
with tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) as tmp_file:
temp_path = tmp_file.name
blob.download_to_file(tmp_file)
try:
return read_h5ad(temp_path, backed=backed)
finally:
try:
os.unlink(temp_path) # clean up the temp file
except OSError:
pass # if there's an error during cleanup, continue


def read_h5ad_url(filename: str) -> AnnData:
def read_h5ad_url(filename: str, backed: backed_mode_type = backed_mode_default) -> AnnData:
r"""
Read an ``.h5ad``-formatted hdf5 file from a URL.

@@ -48,37 +68,59 @@ def read_h5ad_url(filename: str) -> AnnData:
>>> adata = read_h5ad_url(
... "https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad"
... )
>>> adata = read_h5ad_url(
... "https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad",
... backed='r'
... )

Args:
filename: URL of the data file.
backed: See :func:`anndata.read_h5ad` for details on backed mode.
``'r'`` or ``True`` loads in backed mode instead of fully loading into memory;
``False`` or ``None`` uses in-memory mode.
"""
if not any(filename.startswith(scheme) for scheme in url_schemes):
raise ValueError("The filename must start with 'http:', 'https:', or 'ftp:' protocol name.")
with urllib.request.urlopen(filename) as response:
with tempfile.TemporaryFile() as tmp_file:

# write to a named temporary file
with tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) as tmp_file:
temp_path = tmp_file.name
with urllib.request.urlopen(filename) as response:
shutil.copyfileobj(response, tmp_file)
return read_h5ad(tmp_file)
try:
return read_h5ad(temp_path, backed=backed)
finally:
try:
os.unlink(temp_path) # clean up the temp file
except OSError:
pass # if there's an error during cleanup, continue


def read_h5ad_local(filename: str) -> AnnData:
def read_h5ad_local(filename: str, backed: backed_mode_type = backed_mode_default) -> AnnData:
r"""
Read an ``.h5ad``-formatted hdf5 file from the local disk.

Args:
filename: Path to the local data file.
backed: See :func:`anndata.read_h5ad` for details on backed mode.
``'r'`` or ``True`` loads in backed mode instead of fully loading into memory;
``False`` or ``None`` uses in-memory mode.
"""
if not filename.startswith("file:"):
raise ValueError("The filename must start with 'file:' protocol name.")
filename = re.sub(r"^file://?", "", filename)
return read_h5ad(filename)
return read_h5ad(filename, backed=backed)


def read_h5ad_file(filename: str, **kwargs) -> AnnData:
def read_h5ad_file(filename: str, backed: backed_mode_type = backed_mode_default, **kwargs) -> AnnData:
r"""
Read an ``.h5ad``-formatted hdf5 file from a filename.

Args:
filename: Path to the data file.
backed: See :func:`anndata.read_h5ad` for details on backed mode.
``'r'`` or ``True`` loads in backed mode instead of fully loading into memory;
``False`` or ``None`` uses in-memory mode.
"""
if filename.startswith("gs:"):
return read_h5ad_gcs(filename, backed=backed, **kwargs)
@@ -89,4 +131,4 @@ def read_h5ad_file(filename: str, **kwargs) -> AnnData:
if any(filename.startswith(scheme) for scheme in url_schemes):
return read_h5ad_url(filename, backed=backed)

return read_h5ad(filename)
return read_h5ad(filename, backed=backed)
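A minimal sketch of the updated reader under both modes; the local path is hypothetical, and ``isbacked`` is the standard :class:`anndata.AnnData` property that is ``True`` when ``X`` remains on disk.

# Hypothetical usage sketch (not part of this diff).
from cellarium.ml.data.fileio import read_h5ad_local

adata_backed = read_h5ad_local("file:///tmp/test_0.h5ad", backed="r")  # hypothetical path
assert adata_backed.isbacked  # X stays on disk until sliced

adata_memory = read_h5ad_local("file:///tmp/test_0.h5ad", backed=None)
assert not adata_memory.isbacked  # X fully loaded into memory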
83 changes: 73 additions & 10 deletions tests/dataloader/test_datamodule.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import Any

import lightning.pytorch as pl
import pytest
@@ -176,14 +177,15 @@ def _check_transform_lists_match(iterable_modules1, iterable_modules2, message):
"accelerator",
["cpu", pytest.param("gpu", marks=pytest.mark.skipif(not USE_CUDA, reason="requires_cuda"))],
)
@pytest.mark.parametrize("batch_size", [50, None])
def test_datamodule(tmp_path: Path, batch_size: int | None, accelerator: str) -> None:
@pytest.mark.parametrize("batch_size", [50, 100])
def test_datamodule(tmp_path: Path, batch_size: int, accelerator: str) -> None:
dadc = DistributedAnnDataCollection(
filenames="https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad",
shard_size=100,
)
datamodule = CellariumAnnDataDataModule(
DistributedAnnDataCollection(
filenames="https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad",
shard_size=100,
),
batch_size=100,
dadc=dadc,
batch_size=batch_size,
batch_keys={
"x_ng": AnnDataField(attr="X", convert_fn=densify),
},
@@ -193,15 +195,76 @@ def test_datamodule(tmp_path: Path, batch_size: int | None, accelerator: str) -> None:
trainer.fit(module, datamodule)

ckpt_path = str(tmp_path / "lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt")
adata = datamodule.dadc.adatas[0].adata
kwargs = {"dadc": adata}
kwargs: dict[str, Any] = {"dadc": dadc}
if batch_size is not None:
kwargs["batch_size"] = batch_size
loaded_datamodule = CellariumAnnDataDataModule.load_from_checkpoint(ckpt_path, **kwargs)

assert loaded_datamodule.batch_keys == datamodule.batch_keys
assert loaded_datamodule.batch_size == (batch_size or datamodule.batch_size)
assert loaded_datamodule.dadc is adata
assert loaded_datamodule.dadc is dadc


@pytest.fixture
def fake_massive_dense_h5ad(tmp_path: Path) -> Path:
import h5py
import numpy as np

# Create a dataset that CLAIMS to be ~40GB but uses almost no disk space
n_obs = 2_000_000 # 2 million cells
n_vars = 5_000 # 5k genes

h5ad_path = tmp_path / "massive_fake.h5ad"

with h5py.File(h5ad_path, "w") as f:
# Create X dataset with claimed huge size but minimal actual storage
# Using fillvalue=0.0 with chunking - chunks are only allocated when written to
f.create_dataset(
"X",
shape=(n_obs, n_vars),
dtype=np.float32,
fillvalue=0.0,
chunks=True, # Enable chunking so not all data needs to be stored
compression=None, # No compression to keep it simple
)

# Create minimal obs metadata - just the index is required
obs_group = f.create_group("obs")
# Write the full obs index of n_obs cell names (~24 MB of strings,
# still negligible next to the ~40 GB claimed for X)
obs_index_data = np.array([f"CELL_{i:07d}".encode("utf-8") for i in range(n_obs)])
obs_group.create_dataset("_index", data=obs_index_data, maxshape=(n_obs,), dtype="S12")

# Create minimal var metadata - just the index is required
var_group = f.create_group("var")
var_index_data = np.array([f"GENE_{i:05d}".encode("utf-8") for i in range(n_vars)])
var_group.create_dataset("_index", data=var_index_data, dtype="S10")

# Set minimal h5ad format attributes that anndata expects
f.attrs["encoding-type"] = "anndata"
f.attrs["encoding-version"] = "0.1.0"

return h5ad_path


def test_datamodule_massive_h5ad_backed(tmp_path: Path, fake_massive_dense_h5ad: Path) -> None:
# try training using a massive (faked) h5ad file which should only succeed if backed mode works
dadc = DistributedAnnDataCollection(
filenames=str(fake_massive_dense_h5ad), # Use full path instead of just name
shard_size=2_000_000,
backed=True,
)
datamodule = CellariumAnnDataDataModule(
dadc=dadc,
batch_size=100,
batch_keys={
"x_ng": AnnDataField(attr="X", convert_fn=None), # already dense
},
)
module = CellariumModule(model=BoringModel())
trainer = pl.Trainer(accelerator="cpu", devices=1, max_steps=1, default_root_dir=tmp_path)
trainer.fit(module, datamodule)
# the idea is that if this runs without a memory overflow, backed mode is implemented correctly
# we have separately verified that backed=False crashes due to ~40GB memory use


@pytest.mark.parametrize(
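As a sanity check on the fixture's trick, a standalone sketch (not part of this PR) showing why the fake 40 GB file costs almost no disk: HDF5 allocates chunks of a chunked dataset lazily, so a dataset holding only its fill value stores no chunk data.

# Standalone sketch: a dataset whose logical size is ~40 GB occupies only a few
# kilobytes on disk because no chunks have been written.
import os
import h5py
import numpy as np

path = "sparse_alloc_demo.h5"  # hypothetical scratch file
with h5py.File(path, "w") as f:
    f.create_dataset("X", shape=(2_000_000, 5_000), dtype=np.float32,
                     fillvalue=0.0, chunks=True)

logical_gb = 2_000_000 * 5_000 * 4 / 1e9  # float32 bytes if fully materialized
print(f"logical ~{logical_gb:.0f} GB, on disk {os.path.getsize(path)} bytes")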