Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions src/anndata/experimental/pytorch/_annloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from scipy.sparse import issparse

from ..._core.anndata import AnnData
from ...compat import old_positionals
from ..._warnings import warn
from ...compat import Empty, old_positionals
from ..multi_files._anncollection import AnnCollection, _ConcatViewMixin

if find_spec("torch") or TYPE_CHECKING:
Expand All @@ -22,6 +23,7 @@

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Sequence
from typing import Literal

from scipy.sparse import spmatrix

Expand Down Expand Up @@ -70,21 +72,19 @@ def __len__(self) -> int:
return length


# maybe replace use_cuda with explicit device option
def default_converter(arr: Array, *, use_cuda: bool, pin_memory: bool):
def default_converter(
arr: Array, *, device: Literal["cpu", "cuda", "mps"] = "cpu", pin_memory: bool
):
if isinstance(arr, torch.Tensor):
if use_cuda:
arr = arr.cuda()
elif pin_memory:
arr = arr.to(device)
if device == "cpu" and pin_memory:
arr = arr.pin_memory()
elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number):
if issparse(arr):
arr = arr.toarray()
if use_cuda:
arr = torch.tensor(arr, device="cuda")
else:
arr = torch.tensor(arr)
arr = arr.pin_memory() if pin_memory else arr
arr = torch.tensor(arr, device=device)
if device == "cpu" and pin_memory:
arr = arr.pin_memory()
return arr


Expand Down Expand Up @@ -135,12 +135,15 @@ class AnnLoader(DataLoader):
Set to `True` to have the data reshuffled at every epoch.
use_default_converter
Use the default converter to convert arrays to pytorch tensors, transfer to
the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`).
the specified device (if `device` is set), do memory pinning (if `pin_memory=True`).
If you pass an AnnCollection object with prespecified converters, the default converter
won't overwrite these converters but will be applied on top of them.
device
Transfer pytorch tensors to the specified device after conversion.
Only works if `use_default_converter=True`.
use_cuda
Transfer pytorch tensors to the default cuda device after conversion.
Only works if `use_default_converter=True`
.. deprecated::
Use `device='cuda'` instead.
**kwargs
Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
arguments for `AnnCollection` initialization.
Expand All @@ -154,9 +157,23 @@ def __init__(
batch_size: int = 1,
shuffle: bool = False,
use_default_converter: bool = True,
use_cuda: bool = False,
device: Literal["cpu", "cuda", "mps"] = "cpu",
use_cuda: bool = Empty.TOKEN,
**kwargs,
):
if use_cuda is not Empty.TOKEN:
if device != "cpu":
msg = (
"Cannot specify both 'device' and 'use_cuda'. Use 'device' instead."
)
raise ValueError(msg)
warn(
"'use_cuda' is deprecated, use 'device' instead. "
"Pass device='cuda' instead of use_cuda=True.",
FutureWarning,
)
device = "cuda" if use_cuda else "cpu"

if isinstance(adatas, AnnData):
adatas = [adatas]

Expand Down Expand Up @@ -191,7 +208,7 @@ def __init__(
if use_default_converter:
pin_memory = kwargs.pop("pin_memory", False)
_converter = partial(
default_converter, use_cuda=use_cuda, pin_memory=pin_memory
default_converter, device=device, pin_memory=pin_memory
)
dataset.convert = _convert_on_top(
dataset.convert, _converter, dict(dataset.attrs_keys, X=[])
Expand Down
42 changes: 42 additions & 0 deletions tests/test_annloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

import numpy as np
import pytest

import anndata as ad

pytest.importorskip("torch")

from anndata.experimental.pytorch import AnnLoader


@pytest.fixture
def adata():
return ad.AnnData(X=np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))


def test_annloader_default_device(adata):
"""AnnLoader with default device='cpu' produces CPU tensors."""
loader = AnnLoader(adata, batch_size=2)
batch = next(iter(loader))
assert batch.X.device.type == "cpu"


def test_annloader_explicit_cpu_device(adata):
"""AnnLoader with explicit device='cpu' produces CPU tensors."""
loader = AnnLoader(adata, batch_size=2, device="cpu")
batch = next(iter(loader))
assert batch.X.device.type == "cpu"


def test_annloader_use_cuda_deprecation_warning(adata):
"""Passing use_cuda emits a FutureWarning."""
with pytest.warns(FutureWarning, match="use_cuda.*deprecated"):
# use_cuda=False should still emit warning (parameter was explicitly passed)
AnnLoader(adata, batch_size=2, use_cuda=False)


def test_annloader_use_cuda_and_device_conflict(adata):
"""Passing both use_cuda and device raises ValueError."""
with pytest.raises(ValueError, match="Cannot specify both"):
AnnLoader(adata, batch_size=2, use_cuda=True, device="cuda")
Loading