diff --git a/.github/workflows/test_linux_benchmark.yml b/.github/workflows/test_linux_benchmark.yml new file mode 100644 index 0000000000..7f6bb23ce8 --- /dev/null +++ b/.github/workflows/test_linux_benchmark.yml @@ -0,0 +1,75 @@ +name: test (benchmark) + +on: + pull_request: + branches: [main, "[0-9]+.[0-9]+.x"] + types: [labeled, synchronize, opened] + schedule: + - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + # if PR has label "benchmark tests" or "all tests" or if scheduled or manually triggered or on push + if: >- + ( + contains(github.event.pull_request.labels.*.name, 'benchmark tests') || + contains(github.event.pull_request.labels.*.name, 'all tests') || + (contains(github.event_name, 'schedule') && github.repository == 'scverse/scvi-tools') || + contains(github.event_name, 'workflow_dispatch') || + contains(github.event_name, 'push') + ) + + runs-on: ${{ matrix.os }} + + defaults: + run: + shell: bash -e {0} # -e to fail on error + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.13"] + + permissions: + id-token: write + + name: unit + + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: "pip" + cache-dependency-path: "**/pyproject.toml" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel uv + python -m uv pip install --system "scvi-tools[tests] @ ." + + - name: Run pytest + env: + MPLBACKEND: agg + PLATFORM: ${{ matrix.os }} + DISPLAY: :42 + COLUMNS: 120 + run: | + coverage run -m pytest -v --color=yes --benchmark + coverage report + + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + flags: benchmark diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f7c45ac5fa..9019337ea9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,7 @@ fail_fast: false default_language_version: python: python3 + node: 20.20.2 default_stages: - pre-commit - pre-push diff --git a/CHANGELOG.md b/CHANGELOG.md index 79fc9ec3af..a3719e658a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo - Add support for Python 3.14, {pr}`3563`. - Add support for Pandas3, {pr}`3638`. +- Add graph-aware dataloading support for {class}`scvi.external.RESOLVI` with `torch_geometric` {pr}`37XX`. #### Fixed diff --git a/pyproject.toml b/pyproject.toml index 423018afba..e92b90f41e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ docs = [ "myst-nb", "sphinx-autodoc-typehints", ] -docsbuild = ["scvi-tools[docs,autotune,hub,jax]","mlx"] +docsbuild = ["scvi-tools[docs,autotune,hub,jax,diagvi]","mlx"] # scvi.autotune #TODO remove ray[tune] constraint once solved autotune = ["hyperopt>=0.2", "ray[tune]; python_version < '3.14'", "scib-metrics", "muon"] @@ -97,9 +97,11 @@ jax = ["jax<0.10.0", "jaxlib", "optax", "numpyro", "flax"] #TODO: unpin once it dataloaders = ["lamindb>=1.12.1", "cellxgene-census", "tiledbsoma", "tiledbsoma_ml", "torchdata"] # for mlflow mlflow = ["mlflow","psutil","GPUtil","nvidia-ml-py"] +# for diagvi +diagvi = ["torch_geometric", "geomloss"] optional = [ - "scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability]", + "scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability,diagvi]", "igraph","leidenalg","pynndescent", ] tutorials = [ @@ -131,6 +133,7 @@ omit = [ testpaths = ["tests"] xfail_strict = true markers = [ + "benchmark: mark benchmark/performance tests", "internet: mark tests that requires internet access", "optional: mark optional tests, usually take more time", "private: mark tests that uses private keys, like HF", @@ -140,6 +143,7 @@ markers = [ "dataloader: mark tests that are used to check data loaders", "jax: mark test as jax related", "mlflow: mark test for mlflow", + "diagvi: mark test for diagvi and torch_geometric based models", ] [tool.ruff] diff --git a/src/scvi/dataloaders/__init__.py b/src/scvi/dataloaders/__init__.py index 0ebf821423..94f336e475 100644 --- a/src/scvi/dataloaders/__init__.py +++ b/src/scvi/dataloaders/__init__.py @@ -10,6 +10,7 @@ DeviceBackedDataSplitter, SemiSupervisedDataSplitter, ) +from ._graph_dataloader import GraphDataLoader, GraphDataSplitter from ._samplers import BatchDistributedSampler from ._semi_dataloader import SemiSupervisedDataLoader @@ -19,10 +20,12 @@ "CollectionAdapter", "ConcatDataLoader", "DeviceBackedDataSplitter", - "SemiSupervisedDataLoader", "DataSplitter", - "SemiSupervisedDataSplitter", + "GraphDataLoader", + "GraphDataSplitter", "BatchDistributedSampler", "MappedCollectionDataModule", + "SemiSupervisedDataLoader", + "SemiSupervisedDataSplitter", "TileDBDataModule", ] diff --git a/src/scvi/dataloaders/_graph_dataloader.py b/src/scvi/dataloaders/_graph_dataloader.py new file mode 100644 index 0000000000..a1d76002f9 --- /dev/null +++ b/src/scvi/dataloaders/_graph_dataloader.py @@ -0,0 +1,227 @@ +"""Graph-aware dataloaders for spatial single-cell models.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import torch +from torch.utils.data import default_convert + +from scvi import REGISTRY_KEYS +from scvi.data import AnnTorchDataset +from scvi.dataloaders._ann_dataloader import AnnDataLoader +from scvi.dataloaders._data_splitting import DataSplitter + +if TYPE_CHECKING: + from scvi.data import AnnDataManager + + +def _as_torch_tensor(array: np.ndarray | torch.Tensor) -> torch.Tensor: + """Convert AnnTorchDataset output to a torch tensor without changing sparse layout.""" + if isinstance(array, np.ndarray): + return torch.from_numpy(array) + return array + + +class _GraphBatchConverter: + """Convert an AnnTorchDataset batch into a PyG Data object.""" + + def __init__( + self, + full_adata_manager: AnnDataManager, + neighbor_indices_key: str, + edge_obsm_keys: list[str], + load_sparse_neighbor_tensor: bool, + load_neighbor_expression: bool, + ): + self.neighbor_indices_key = neighbor_indices_key + self.edge_obsm_keys = edge_obsm_keys + self.load_neighbor_expression = load_neighbor_expression + if load_neighbor_expression: + self._full_dataset = AnnTorchDataset( + full_adata_manager, + getitem_tensors=[REGISTRY_KEYS.X_KEY], + load_sparse_tensor=load_sparse_neighbor_tensor, + ) + + def __call__(self, batch: dict[str, np.ndarray | torch.Tensor]): + try: + from torch_geometric.data import Data + except ImportError as error: + raise ImportError( + "torch_geometric is required for GraphDataLoader. " + "Install it with: pip install torch_geometric" + ) from error + + batch = default_convert(batch) + ind_neighbors = batch[self.neighbor_indices_key].long() + n_obs, n_neighbors = ind_neighbors.shape + + x = _as_torch_tensor(batch[REGISTRY_KEYS.X_KEY]) + + center_idx = torch.arange(n_obs, dtype=torch.long).repeat_interleave(n_neighbors) + neighbor_idx = torch.arange(n_obs * n_neighbors, dtype=torch.long) + edge_index = torch.stack([center_idx, neighbor_idx], dim=0) + + edge_attrs = [] + for key in self.edge_obsm_keys: + vals = batch[key].float() + edge_attrs.append(vals.reshape(n_obs * n_neighbors, -1)) + edge_attr = torch.cat(edge_attrs, dim=1) if edge_attrs else None + + data_kwargs = dict(batch) + data_kwargs.update( + { + "x": x, + "edge_index": edge_index, + "edge_attr": edge_attr, + "distances_n": batch.get("distance_neighbor"), + } + ) + if self.load_neighbor_expression: + flat_neighbors = ind_neighbors.cpu().numpy().ravel() + data_kwargs["x_n"] = _as_torch_tensor( + self._full_dataset[flat_neighbors][REGISTRY_KEYS.X_KEY] + ) + return Data(**data_kwargs) + + +class GraphDataLoader(AnnDataLoader): + """DataLoader that yields mini-batches as :class:`torch_geometric.data.Data` objects. + + Each batch contains center cells and their pre-fetched spatial neighbors. Neighbor + expression is looked up from ``full_adata_manager`` so neighbors outside the current + train/validation/test split are intentionally allowed, matching existing RESOLVI behavior. + + Parameters + ---------- + adata_manager + :class:`~scvi.data.AnnDataManager` for the split being loaded. + full_adata_manager + :class:`~scvi.data.AnnDataManager` for all observations. Used for neighbor expression + lookup, including cross-split neighbors. + indices + Observation indices to load from ``adata_manager``. + neighbor_indices_key + Registry key containing neighbor indices, shape ``[N, K]``. + edge_obsm_keys + Registry keys to flatten and concatenate into ``edge_attr``. Each key must have shape + ``[N, K]`` or ``[N, K, D]``. Defaults to ``["distance_neighbor"]``. + load_sparse_neighbor_tensor + If ``True``, loads sparse neighbor expression as sparse torch tensors. This avoids + densifying neighbor expression on the CPU before device transfer. + load_neighbor_expression + If ``False``, omits ``x_n`` and leaves neighbor expression gathering to the model. This is + useful when a model keeps a device-resident expression cache. + **kwargs + Forwarded to :class:`~scvi.dataloaders.AnnDataLoader`. + """ + + def __init__( + self, + adata_manager: AnnDataManager, + full_adata_manager: AnnDataManager, + indices: list[int] | list[bool] | None = None, + neighbor_indices_key: str = "index_neighbor", + edge_obsm_keys: list[str] | None = None, + load_sparse_neighbor_tensor: bool = True, + load_neighbor_expression: bool = True, + **kwargs, + ): + if "collate_fn" in kwargs: + raise ValueError("GraphDataLoader uses its own collate_fn to build graph batches.") + if kwargs.pop("iter_ndarray", False): + raise ValueError("GraphDataLoader does not support `iter_ndarray=True`.") + + super().__init__(adata_manager, indices=indices, **kwargs) + self.neighbor_indices_key = neighbor_indices_key + self.edge_obsm_keys = ( + list(edge_obsm_keys) if edge_obsm_keys is not None else ["distance_neighbor"] + ) + self.load_sparse_neighbor_tensor = load_sparse_neighbor_tensor + self.load_neighbor_expression = load_neighbor_expression + self._graph_batch_converter = _GraphBatchConverter( + full_adata_manager, + neighbor_indices_key=self.neighbor_indices_key, + edge_obsm_keys=self.edge_obsm_keys, + load_sparse_neighbor_tensor=load_sparse_neighbor_tensor, + load_neighbor_expression=load_neighbor_expression, + ) + self.collate_fn = self._graph_batch_converter + + +class GraphDataSplitter(DataSplitter): + """DataSplitter that creates :class:`GraphDataLoader` instances. + + Parameters + ---------- + load_sparse_neighbor_tensor + Forwarded to :class:`GraphDataLoader`. + load_neighbor_expression + Forwarded to :class:`GraphDataLoader`. + """ + + def __init__( + self, + adata_manager: AnnDataManager, + neighbor_indices_key: str = "index_neighbor", + edge_obsm_keys: list[str] | None = None, + load_sparse_neighbor_tensor: bool = True, + load_neighbor_expression: bool = True, + **kwargs, + ): + super().__init__(adata_manager, **kwargs) + self.neighbor_indices_key = neighbor_indices_key + self.edge_obsm_keys = ( + list(edge_obsm_keys) if edge_obsm_keys is not None else ["distance_neighbor"] + ) + self.load_sparse_neighbor_tensor = load_sparse_neighbor_tensor + self.load_neighbor_expression = load_neighbor_expression + + def _make_graph_dataloader( + self, + indices: np.ndarray, + shuffle: bool, + drop_last: bool, + ) -> GraphDataLoader: + return GraphDataLoader( + self.adata_manager, + full_adata_manager=self.adata_manager, + indices=indices, + shuffle=shuffle, + drop_last=drop_last, + load_sparse_tensor=self.load_sparse_tensor, + pin_memory=self.pin_memory, + neighbor_indices_key=self.neighbor_indices_key, + edge_obsm_keys=self.edge_obsm_keys, + load_sparse_neighbor_tensor=self.load_sparse_neighbor_tensor, + load_neighbor_expression=self.load_neighbor_expression, + **self.data_loader_kwargs, + ) + + def train_dataloader(self) -> GraphDataLoader: + """Create graph train dataloader.""" + return self._make_graph_dataloader( + self.train_idx, + shuffle=True, + drop_last=self.drop_last, + ) + + def val_dataloader(self) -> GraphDataLoader | None: + """Create graph validation dataloader.""" + if len(self.val_idx) > 0: + return self._make_graph_dataloader( + self.val_idx, + shuffle=False, + drop_last=False, + ) + + def test_dataloader(self) -> GraphDataLoader | None: + """Create graph test dataloader.""" + if len(self.test_idx) > 0: + return self._make_graph_dataloader( + self.test_idx, + shuffle=False, + drop_last=False, + ) diff --git a/src/scvi/external/resolvi/_model.py b/src/scvi/external/resolvi/_model.py index d4503f5c2b..57976bf83b 100644 --- a/src/scvi/external/resolvi/_model.py +++ b/src/scvi/external/resolvi/_model.py @@ -21,7 +21,7 @@ NumericalObsField, ObsmField, ) -from scvi.dataloaders import AnnTorchDataset +from scvi.dataloaders import AnnTorchDataset, GraphDataSplitter from scvi.model._utils import ( scrna_raw_counts_properties, ) @@ -93,6 +93,7 @@ class RESOLVI( """ _module_cls = RESOLVAE + _data_splitter_cls = GraphDataSplitter _block_parameters = [] def __init__( @@ -194,6 +195,8 @@ def train( n_epochs_kl_warmup: int | None = 20, plan_kwargs: dict | None = None, expose_params: list = (), + cache_neighbor_expression: bool | Literal["auto"] = "auto", + neighbor_expression_cache_max_bytes: int | None = 1_000_000_000, **kwargs, ): """ @@ -211,6 +214,14 @@ def train( List of parameters to train with `lr_extra` learning rate. batch_size Minibatch size to use during training. + cache_neighbor_expression + Whether to cache the full expression matrix densely and gather neighbor expression in + the RESOLVI module. ``"auto"`` enables the cache for graph dataloading only when the + estimated dense cache size is below ``neighbor_expression_cache_max_bytes``. Set this + to ``True`` to explicitly enable the cache for non-graph dataloaders. + neighbor_expression_cache_max_bytes + Maximum estimated dense expression cache size allowed in ``"auto"`` mode. Set to + ``None`` to disable the size guard. weight_decay weight decay regularization term for optimization eps @@ -267,11 +278,24 @@ def per_param_callable(module_name, param_name): } ) + datasplitter_kwargs = dict(kwargs.pop("datasplitter_kwargs", {}) or {}) + uses_graph_splitter = issubclass(self._data_splitter_cls, GraphDataSplitter) + cache_request = cache_neighbor_expression + if cache_request == "auto" and not uses_graph_splitter: + cache_request = False + cache_enabled = self.module.configure_neighbor_expression_cache( + cache=cache_request, + max_bytes=neighbor_expression_cache_max_bytes, + ) + if cache_enabled and uses_graph_splitter: + datasplitter_kwargs.setdefault("load_neighbor_expression", False) + super().train( max_epochs=max_epochs, train_size=1.0, plan_kwargs=plan_kwargs, batch_size=batch_size, + datasplitter_kwargs=datasplitter_kwargs, **kwargs, ) diff --git a/src/scvi/external/resolvi/_module.py b/src/scvi/external/resolvi/_module.py index 0ef8fe3a00..399b3797d0 100644 --- a/src/scvi/external/resolvi/_module.py +++ b/src/scvi/external/resolvi/_module.py @@ -160,6 +160,9 @@ def __init__( self.eps = torch.tensor(1e-6) self.encode_covariates = encode_covariates self.size_scaling = size_scaling + self._use_neighbor_expression_cache = False + self._neighbor_expression_cache = None + self._neighbor_expression_cache_max_bytes = None if self.dispersion == "gene": init_px_r = torch.full([n_input], 0.01) @@ -252,6 +255,57 @@ def __init__( var_eps=1e-6, ) + def _estimate_neighbor_expression_cache_bytes(self) -> int: + """Estimate dense expression cache size in bytes.""" + float32_size = torch.tensor([], dtype=torch.float32).element_size() + return len(self.expression_anntorchdata) * self.n_input * float32_size + + def configure_neighbor_expression_cache( + self, + cache: bool | str = "auto", + max_bytes: int | None = 1_000_000_000, + ) -> bool: + """Enable a dense expression cache used for device-side neighbor gathers.""" + if cache == "auto": + estimated_bytes = self._estimate_neighbor_expression_cache_bytes() + enabled = max_bytes is None or estimated_bytes <= max_bytes + elif isinstance(cache, bool): + enabled = cache + else: + raise ValueError("`cache` must be a bool or 'auto'.") + + self._use_neighbor_expression_cache = enabled + self._neighbor_expression_cache_max_bytes = max_bytes + if not enabled: + self.clear_neighbor_expression_cache() + return enabled + + def clear_neighbor_expression_cache(self) -> None: + """Drop the dense neighbor expression cache.""" + self._neighbor_expression_cache = None + + def _get_neighbor_expression_from_cache( + self, + ind_neighbors: torch.Tensor, + x: torch.Tensor, + ) -> torch.Tensor: + """Gather flattened neighbor expression rows from the dense cache.""" + dtype = x.dtype if x.is_floating_point() else torch.float32 + cache = self._neighbor_expression_cache + if cache is None or cache.device != x.device or cache.dtype != dtype: + full_x = self.expression_anntorchdata[np.arange(len(self.expression_anntorchdata))][ + REGISTRY_KEYS.X_KEY + ] + if isinstance(full_x, np.ndarray): + full_x = torch.from_numpy(full_x) + if full_x.layout is torch.sparse_csr or full_x.layout is torch.sparse_csc: + full_x = full_x.to_dense() + self._neighbor_expression_cache = full_x.to(device=x.device, dtype=dtype) + cache = self._neighbor_expression_cache + + flat_neighbors = ind_neighbors.reshape(-1).to(device=cache.device, dtype=torch.long) + return cache.index_select(0, flat_neighbors) + def _get_fn_args_from_batch(self, tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict: x = tensor_dict[REGISTRY_KEYS.X_KEY] y = tensor_dict[REGISTRY_KEYS.LABELS_KEY].long().ravel() @@ -261,12 +315,27 @@ def _get_fn_args_from_batch(self, tensor_dict: dict[str, torch.Tensor]) -> Itera cat_covs = tensor_dict[cat_key] if cat_key in tensor_dict.keys() else None ind_x = tensor_dict[REGISTRY_KEYS.INDICES_KEY].long().ravel() - distances_n = tensor_dict["distance_neighbor"] - ind_neighbors = tensor_dict["index_neighbor"].long() - - x_n = self.expression_anntorchdata[ind_neighbors.cpu().numpy().flatten(), :]["X"] - if isinstance(x_n, np.ndarray): - x_n = torch.from_numpy(x_n) + if self._use_neighbor_expression_cache and "index_neighbor" in tensor_dict: + distances_n = ( + tensor_dict["distances_n"] + if "distances_n" in tensor_dict + else tensor_dict["distance_neighbor"] + ) + ind_neighbors = tensor_dict["index_neighbor"].long() + x_n = self._get_neighbor_expression_from_cache(ind_neighbors, x) + elif "x_n" in tensor_dict: + x_n = tensor_dict["x_n"] + distances_n = ( + tensor_dict["distances_n"] + if "distances_n" in tensor_dict + else tensor_dict["distance_neighbor"] + ) + else: + distances_n = tensor_dict["distance_neighbor"] + ind_neighbors = tensor_dict["index_neighbor"].long() + x_n = self.expression_anntorchdata[ind_neighbors.cpu().numpy().flatten(), :]["X"] + if isinstance(x_n, np.ndarray): + x_n = torch.from_numpy(x_n) x_n = x_n.to(x.device) if x.layout is torch.sparse_csr or x.layout is torch.sparse_csc: @@ -1282,6 +1351,18 @@ def __init__( ) self._get_fn_args_from_batch = self._model._get_fn_args_from_batch + def configure_neighbor_expression_cache( + self, + cache: bool | str = "auto", + max_bytes: int | None = 1_000_000_000, + ) -> bool: + """Configure RESOLVI model-side neighbor expression caching.""" + return self._model.configure_neighbor_expression_cache(cache=cache, max_bytes=max_bytes) + + def clear_neighbor_expression_cache(self) -> None: + """Drop the RESOLVI model-side neighbor expression cache.""" + self._model.clear_neighbor_expression_cache() + @property def model(self): return self._model diff --git a/tests/conftest.py b/tests/conftest.py index f530fff417..03c152caad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,6 +44,12 @@ def pytest_addoption(parser): default=False, help="Run tests that are optional.", ) + parser.addoption( + "--benchmark", + action="store_true", + default=False, + help="Run tests that are benchmarks.", + ) parser.addoption( "--jax", action="store_true", @@ -79,14 +85,31 @@ def pytest_addoption(parser): def pytest_configure(config): """Docstring for pytest_configure.""" config.addinivalue_line("markers", "optional: mark test as optional.") + config.addinivalue_line("markers", "benchmark: mark test for benchmark.") + config.addinivalue_line("markers", "internet: mark test as internet tests.") + config.addinivalue_line("markers", "private: mark test as private tests.") + config.addinivalue_line("markers", "multigpu: mark test as multigpu tests.") + config.addinivalue_line("markers", "autotune: mark test as autotune tests.") + config.addinivalue_line("markers", "custom dataloaders: mark test as custom dataloaders test.") + config.addinivalue_line("markers", "dataloader: mark test as dataloader tests.") + config.addinivalue_line("markers", "jax: mark test as jax tests.") + config.addinivalue_line("markers", "mlflow: mark test as mlflow tests.") + config.addinivalue_line("markers", "diagvi: mark test as diagvi tests.") def pytest_collection_modifyitems(config, items): """Docstring for pytest_collection_modifyitems.""" + run_benchmark = config.getoption("--benchmark") + + def benchmark_selected(item): + return run_benchmark and ("benchmark" in item.keywords) + run_internet = config.getoption("--internet-tests") skip_internet = pytest.mark.skip(reason="need --internet-tests option to run") skip_non_internet = pytest.mark.skip(reason="test not having a pytest.mark.internet decorator") for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.internet` get skipped unless # `--internet-tests` passed if not run_internet and ("internet" in item.keywords): @@ -103,6 +126,8 @@ def pytest_collection_modifyitems(config, items): reason="test not having a pytest.mark.dataloader decorator" ) for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.dataloader` get skipped unless # `--custom-dataloader-tests` passed if not run_custom_dataloader and ("dataloader" in item.keywords): @@ -116,6 +141,8 @@ def pytest_collection_modifyitems(config, items): skip_optional = pytest.mark.skip(reason="need --optional option to run") skip_non_optional = pytest.mark.skip(reason="test not having a pytest.mark.optional decorator") for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.optional` get skipped unless # `--optional` passed if not run_optional and ("optional" in item.keywords): @@ -128,6 +155,8 @@ def pytest_collection_modifyitems(config, items): skip_jax = pytest.mark.skip(reason="need --jax option to run") skip_non_jax = pytest.mark.skip(reason="test not having a pytest.mark.jax decorator") for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.jax` get skipped unless # `--jax` passed if not run_jax and ("jax" in item.keywords): @@ -140,6 +169,8 @@ def pytest_collection_modifyitems(config, items): skip_private = pytest.mark.skip(reason="need --private option to run") skip_non_private = pytest.mark.skip(reason="test not having a pytest.mark.private decorator") for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.private` get skipped unless # `--private` passed if not run_private and ("private" in item.keywords): @@ -152,6 +183,8 @@ def pytest_collection_modifyitems(config, items): skip_multigpu = pytest.mark.skip(reason="need --multigpu-tests option to run") skip_non_multigpu = pytest.mark.skip(reason="test not having a pytest.mark.multigpu decorator") for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.multigpu` get skipped unless # `--multigpu-tests` passed if not run_multigpu and ("multigpu" in item.keywords): @@ -164,6 +197,8 @@ def pytest_collection_modifyitems(config, items): skip_autotune = pytest.mark.skip(reason="need --autotune-tests option to run") skip_non_autotune = pytest.mark.skip(reason="test not having a pytest.mark.autotune decorator") for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.autotune` get skipped unless # `--autotune-tests` passed if not run_autotune and ("autotune" in item.keywords): @@ -176,6 +211,8 @@ def pytest_collection_modifyitems(config, items): skip_mlflow = pytest.mark.skip(reason="need --mlflow-tests option to run") skip_non_mlflow = pytest.mark.skip(reason="test not having a pytest.mark.mlflow decorator") for item in items: + if benchmark_selected(item): + continue # All tests marked with `pytest.mark.mlflow` get skipped unless # `--mlflow-tests` passed if not run_mlflow and ("mlflow" in item.keywords): @@ -184,6 +221,19 @@ def pytest_collection_modifyitems(config, items): elif run_mlflow and ("mlflow" not in item.keywords): item.add_marker(skip_non_mlflow) + skip_benchmark = pytest.mark.skip(reason="need --benchmark option to run") + skip_non_benchmark = pytest.mark.skip( + reason="test not having a pytest.mark.benchmark decorator" + ) + for item in items: + # All tests marked with `pytest.mark.benchmark` get skipped unless + # `--benchmark` passed + if not run_benchmark and ("benchmark" in item.keywords): + item.add_marker(skip_benchmark) + # Skip all tests not marked with `pytest.mark.benchmark` if `--benchmark` passed + elif run_benchmark and ("benchmark" not in item.keywords): + item.add_marker(skip_non_benchmark) + @pytest.fixture(scope="session") def save_path(tmp_path_factory): diff --git a/tests/dataloaders/test_graph_dataloader.py b/tests/dataloaders/test_graph_dataloader.py new file mode 100644 index 0000000000..da5c927b7e --- /dev/null +++ b/tests/dataloaders/test_graph_dataloader.py @@ -0,0 +1,309 @@ +"""Unit tests for GraphDataLoader.""" + +import numpy as np +import pytest +import scipy.sparse as sp +import torch + +from scvi.data import synthetic_iid +from scvi.external import RESOLVI + + +@pytest.fixture(scope="module") +def resolvi_adata(): + adata = synthetic_iid(generate_coordinates=True, n_regions=5) + adata.obsm["X_spatial"] = adata.obsm["coordinates"] + n_obs = adata.n_obs + n_neighbors = 10 + adata.obsm["index_neighbor"] = ( + np.arange(n_obs)[:, None] + np.arange(1, n_neighbors + 1)[None, :] + ) % n_obs + adata.obsm["distance_neighbor"] = np.ones((n_obs, n_neighbors), dtype=np.float32) + RESOLVI.setup_anndata(adata, prepare_data=False) + return adata + + +@pytest.fixture(scope="module") +def adata_manager(resolvi_adata): + return RESOLVI._get_most_recent_anndata_manager(resolvi_adata, required=True) + + +def test_graph_dataloader_yields_data_objects(adata_manager): + """Each batch must be a torch_geometric Data object, not a plain dict.""" + from torch_geometric.data import Data + + from scvi.dataloaders import GraphDataLoader + + dl = GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=32, + shuffle=False, + ) + batch = next(iter(dl)) + assert isinstance(batch, Data), f"Expected Data, got {type(batch)}" + + +def test_graph_dataloader_shapes(adata_manager): + """x: [N, G], x_n: [N*K, G], edge_index: [2, N*K], edge_attr: [N*K, 1].""" + from scvi.dataloaders import GraphDataLoader + + batch_size = 32 + dl = GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=batch_size, + shuffle=False, + ) + batch = next(iter(dl)) + + n_genes = adata_manager.adata.n_vars + K = adata_manager.adata.obsm["index_neighbor"].shape[1] + N = batch.x.shape[0] # may be < batch_size on last batch + + assert batch.x.shape == (N, n_genes), f"x shape wrong: {batch.x.shape}" + assert batch.x_n.shape == (N * K, n_genes), f"x_n shape wrong: {batch.x_n.shape}" + assert batch.edge_index.shape == (2, N * K), ( + f"edge_index shape wrong: {batch.edge_index.shape}" + ) + assert batch.edge_attr.shape[0] == N * K, f"edge_attr rows wrong: {batch.edge_attr.shape}" + + +def test_graph_dataloader_edge_index_correctness(adata_manager): + """edge_index[0] = center indices 0..N-1 repeated K times.""" + from scvi.dataloaders import GraphDataLoader + + dl = GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=16, + shuffle=False, + ) + batch = next(iter(dl)) + N = batch.x.shape[0] + K = adata_manager.adata.obsm["index_neighbor"].shape[1] + + expected_src = torch.arange(N).repeat_interleave(K) + torch.testing.assert_close(batch.edge_index[0], expected_src) + + expected_dst = torch.arange(N * K) + torch.testing.assert_close(batch.edge_index[1], expected_dst) + + +def test_graph_dataloader_edge_attr_from_keys(adata_manager): + """distance_neighbor in edge_obsm_keys appears in edge_attr as [N*K, 1].""" + from scvi.dataloaders import GraphDataLoader + + dl = GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=16, + shuffle=False, + edge_obsm_keys=["distance_neighbor"], + ) + batch = next(iter(dl)) + N = batch.x.shape[0] + K = adata_manager.adata.obsm["index_neighbor"].shape[1] + + # distance_neighbor is [N, K] flattened to [N*K, 1] + assert batch.edge_attr.shape == (N * K, 1) + assert batch.edge_attr.dtype == torch.float32 + + +def test_graph_dataloader_preserves_sparse_neighbor_expression_by_default(): + """Sparse neighbor expression should not be densified in the dataloader.""" + from scvi.dataloaders import GraphDataLoader + + adata = synthetic_iid(generate_coordinates=True, n_regions=5) + adata.X = sp.csr_matrix(adata.X) + n_obs = adata.n_obs + n_neighbors = 10 + adata.obsm["index_neighbor"] = ( + np.arange(n_obs)[:, None] + np.arange(1, n_neighbors + 1)[None, :] + ) % n_obs + adata.obsm["distance_neighbor"] = np.ones((n_obs, n_neighbors), dtype=np.float32) + + RESOLVI.setup_anndata(adata, prepare_data=False) + adata_manager = RESOLVI._get_most_recent_anndata_manager(adata, required=True) + + batch = next( + iter( + GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=8, + shuffle=False, + ) + ) + ) + + assert batch.x_n.layout is torch.sparse_csr + + dense_batch = next( + iter( + GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=8, + shuffle=False, + load_sparse_neighbor_tensor=False, + ) + ) + ) + assert dense_batch.x_n.layout is torch.strided + + +def test_graph_dataloader_cross_split_neighbors(adata_manager): + """Neighbor indices outside train split must resolve without error.""" + from scvi.dataloaders import GraphDataLoader + + n_obs = adata_manager.adata.n_obs + train_indices = np.arange(n_obs // 2) # first half only + + dl = GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, # full dataset — cross-split intentional + indices=train_indices, + batch_size=16, + shuffle=False, + ) + for batch in dl: + assert batch.x_n is not None + break + + +def test_graph_dataloader_builds_graph_batches_in_collate_fn(adata_manager): + """Graph batch construction should happen in the loader collate function.""" + from torch_geometric.data import Data + + from scvi.dataloaders import GraphDataLoader + + dl = GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=8, + shuffle=False, + ) + raw_batch = dl.dataset[np.arange(8)] + batch = dl.collate_fn(raw_batch) + + assert isinstance(batch, Data) + assert hasattr(batch, "x_n") + assert hasattr(batch, "edge_index") + + +def test_graph_dataloader_can_omit_neighbor_expression(adata_manager): + """Models with their own expression cache can request graph batches without x_n.""" + from scvi.dataloaders import GraphDataLoader + + batch = next( + iter( + GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=8, + shuffle=False, + load_neighbor_expression=False, + ) + ) + ) + + assert "x_n" not in batch + assert "index_neighbor" in batch + assert batch.edge_index.shape[1] == batch.index_neighbor.numel() + + +def test_graph_dataloader_missing_torch_geometric(adata_manager, monkeypatch): + """ImportError with install hint when torch_geometric is absent.""" + import builtins + import sys + + from scvi.dataloaders import GraphDataLoader + + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "torch_geometric" or name.startswith("torch_geometric."): + raise ImportError("No module named 'torch_geometric'") + return real_import(name, *args, **kwargs) + + dl = GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=4, + shuffle=False, + ) + + # Remove torch_geometric from the module cache so our __import__ mock takes effect. + # Without this, Python returns the cached module and never calls __import__. + tg_keys = [ + k for k in sys.modules if k == "torch_geometric" or k.startswith("torch_geometric.") + ] + for k in tg_keys: + monkeypatch.delitem(sys.modules, k) + + monkeypatch.setattr(builtins, "__import__", mock_import) + with pytest.raises(ImportError, match="torch_geometric"): + next(iter(dl)) + + +def test_graph_datasplitter_returns_graph_dataloaders(adata_manager): + """train_dataloader and val_dataloader must return GraphDataLoader instances.""" + from scvi.dataloaders import GraphDataLoader, GraphDataSplitter + + splitter = GraphDataSplitter( + adata_manager, + batch_size=32, + train_size=0.8, + validation_size=0.1, + ) + splitter.setup() + + assert isinstance(splitter.train_dataloader(), GraphDataLoader) + assert isinstance(splitter.val_dataloader(), GraphDataLoader) + + +def test_graph_datasplitter_train_batches_are_data(adata_manager): + """Iterating train dataloader from GraphDataSplitter yields Data objects.""" + from torch_geometric.data import Data + + from scvi.dataloaders import GraphDataSplitter + + splitter = GraphDataSplitter(adata_manager, batch_size=32, train_size=0.8) + splitter.setup() + batch = next(iter(splitter.train_dataloader())) + + assert isinstance(batch, Data) + assert hasattr(batch, "x") + assert hasattr(batch, "x_n") + assert hasattr(batch, "edge_index") + + +def test_graph_datasplitter_custom_edge_keys(adata_manager): + """edge_obsm_keys forwarded to the underlying GraphDataLoader.""" + from scvi.dataloaders import GraphDataSplitter + + splitter = GraphDataSplitter( + adata_manager, + batch_size=16, + train_size=0.8, + edge_obsm_keys=["distance_neighbor"], + ) + splitter.setup() + + assert splitter.train_dataloader().edge_obsm_keys == ["distance_neighbor"] + + +def test_graph_datasplitter_forwards_neighbor_expression_flag(adata_manager): + """load_neighbor_expression should reach split dataloaders.""" + from scvi.dataloaders import GraphDataSplitter + + splitter = GraphDataSplitter( + adata_manager, + batch_size=16, + train_size=0.8, + load_neighbor_expression=False, + ) + splitter.setup() + + assert splitter.train_dataloader().load_neighbor_expression is False diff --git a/tests/external/mrvi_torch/test_torchmrvi_jax_equivalence.py b/tests/external/mrvi_torch/test_torchmrvi_jax_equivalence.py index f4f3a10e80..22438074c0 100644 --- a/tests/external/mrvi_torch/test_torchmrvi_jax_equivalence.py +++ b/tests/external/mrvi_torch/test_torchmrvi_jax_equivalence.py @@ -352,6 +352,7 @@ def _init_torch_module(jax_params): # ── tests ──────────────────────────────────────────────────────────────────── +@pytest.mark.benchmark @pytest.mark.jax def test_forward_pass_equivalence(): """Init JAX, transfer weights to PyTorch, compare all inference/generative/loss outputs.""" @@ -440,6 +441,7 @@ def _extract(val): _compare("total_loss", jax_loss.loss, torch_loss.loss, ATOL_ATTN) +@pytest.mark.benchmark @pytest.mark.jax def test_gradient_equivalence(): """Compare jax.grad vs loss.backward() for all parameters.""" diff --git a/tests/external/resolvi/test_resolvi_graph_dataloader.py b/tests/external/resolvi/test_resolvi_graph_dataloader.py new file mode 100644 index 0000000000..02cc8e4ad6 --- /dev/null +++ b/tests/external/resolvi/test_resolvi_graph_dataloader.py @@ -0,0 +1,465 @@ +"""GraphDataLoader integration tests for RESOLVI.""" + +import time +import warnings + +import numpy as np +import pytest +import scipy.sparse as sp +import torch + +from scvi.data import synthetic_iid +from scvi.external import RESOLVI + + +@pytest.fixture +def adata(): + adata = synthetic_iid(generate_coordinates=True, n_regions=5, n_proteins=10) + adata.obs["cell_area"] = np.random.default_rng(0).gamma(2.0, 1.0, size=adata.n_obs) + _setup_resolvi(RESOLVI, adata) + return adata + + +def _add_ring_neighbors(adata, n_neighbors: int = 10): + adata.obsm["X_spatial"] = adata.obsm["coordinates"] + n_obs = adata.n_obs + adata.obsm["index_neighbor"] = ( + np.arange(n_obs)[:, None] + np.arange(1, n_neighbors + 1)[None, :] + ) % n_obs + adata.obsm["distance_neighbor"] = np.tile( + np.arange(1, n_neighbors + 1, dtype=np.float32), + (n_obs, 1), + ) + + +def _setup_resolvi(cls, adata, **kwargs): + _add_ring_neighbors(adata) + cls.setup_anndata(adata, prepare_data=False, **kwargs) + + +def _resolvi_graph_cls(): + from scvi.dataloaders import GraphDataSplitter + + class RESOLVIGraph(RESOLVI): + _data_splitter_cls = GraphDataSplitter + + return RESOLVIGraph + + +def _resolvi_legacy_cls(): + from scvi.dataloaders import DataSplitter + + class RESOLVILegacy(RESOLVI): + _data_splitter_cls = DataSplitter + + return RESOLVILegacy + + +def _train_graph(model, max_epochs: int = 2, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model.train( + max_epochs=max_epochs, + datasplitter_kwargs={"neighbor_indices_key": "index_neighbor"}, + **kwargs, + ) + + +def test_resolvi_get_fn_args_prefers_graph_batch_x_n(adata): + """RESOLVI must use pre-fetched GraphDataLoader neighbor expression when present.""" + from scvi.dataloaders import GraphDataLoader + + adata_manager = RESOLVI._get_most_recent_anndata_manager(adata, required=True) + model = RESOLVI(adata) + batch = next( + iter( + GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=8, + shuffle=False, + ) + ) + ) + + class FailingExpressionDataset: + def __getitem__(self, item): + raise AssertionError("fallback AnnTorchDataset should not be used") + + model.module.model.expression_anntorchdata = FailingExpressionDataset() + _, kwargs = model.module._get_fn_args_from_batch(batch) + + torch.testing.assert_close(kwargs["x_n"], batch.x_n.reshape(batch.x.shape[0], -1)) + torch.testing.assert_close(kwargs["distances_n"], batch.distances_n) + + +def test_resolvi_get_fn_args_accepts_sparse_graph_batch_x_n(): + """Sparse GraphDataLoader neighbor tensors should be densified inside RESOLVI.""" + from scvi.dataloaders import GraphDataLoader + + adata = synthetic_iid(generate_coordinates=True, n_regions=5, n_proteins=10) + adata.X = sp.csr_matrix(adata.X) + adata.obs["cell_area"] = np.random.default_rng(0).gamma(2.0, 1.0, size=adata.n_obs) + _setup_resolvi(RESOLVI, adata) + + adata_manager = RESOLVI._get_most_recent_anndata_manager(adata, required=True) + model = RESOLVI(adata) + batch = next( + iter( + GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=8, + shuffle=False, + ) + ) + ) + + assert batch.x_n.layout is torch.sparse_csr + _, kwargs = model.module._get_fn_args_from_batch(batch) + + torch.testing.assert_close( + kwargs["x_n"], + batch.x_n.to_dense().reshape(batch.x.shape[0], -1), + ) + + +def test_resolvi_get_fn_args_uses_neighbor_expression_cache_when_x_n_is_absent(adata): + """RESOLVI can gather neighbor expression from its cache instead of a graph batch x_n.""" + from scvi.dataloaders import GraphDataLoader + + adata_manager = RESOLVI._get_most_recent_anndata_manager(adata, required=True) + model = RESOLVI(adata) + model.module.configure_neighbor_expression_cache(cache=True) + batch = next( + iter( + GraphDataLoader( + adata_manager, + full_adata_manager=adata_manager, + batch_size=8, + shuffle=False, + load_neighbor_expression=False, + ) + ) + ) + + assert "x_n" not in batch + _, kwargs = model.module._get_fn_args_from_batch(batch) + assert model.module.model._neighbor_expression_cache is not None + + class FailingExpressionDataset: + def __getitem__(self, item): + raise AssertionError("fallback AnnTorchDataset should not be used") + + model.module.model.expression_anntorchdata = FailingExpressionDataset() + _, cached_kwargs = model.module._get_fn_args_from_batch(batch) + + torch.testing.assert_close(cached_kwargs["x_n"], kwargs["x_n"]) + torch.testing.assert_close(cached_kwargs["distances_n"], batch.distances_n) + + +def test_resolvi_graph_path_trains(adata): + """RESOLVI must train to completion when forced to use GraphDataLoader.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata) + model = RESOLVIGraph(adata) + _train_graph(model) + assert model.module.model._neighbor_expression_cache is not None + + latent = model.get_latent_representation() + assert latent.shape == (adata.n_obs, model.module.n_latent) + + model = RESOLVIGraph(adata, dispersion="gene-batch") + _train_graph(model) + + +def test_resolvi_auto_neighbor_expression_cache_is_graph_splitter_scoped(adata): + """Auto cache is enabled for graph splitters and left off for legacy splitters.""" + RESOLVILegacy = _resolvi_legacy_cls() + + _setup_resolvi(RESOLVILegacy, adata) + model = RESOLVILegacy(adata) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model.train(max_epochs=1, enable_progress_bar=False, logger=False) + + assert model.module.model._use_neighbor_expression_cache is False + assert model.module.model._neighbor_expression_cache is None + + +def test_resolvi_uses_graph_datasplitter_by_default(): + """RESOLVI should opt into GraphDataSplitter without a test-only subclass.""" + from scvi.dataloaders import GraphDataSplitter + + assert RESOLVI._data_splitter_cls is GraphDataSplitter + + +def test_resolvi_graph_train_size_factor(adata): + """GraphDataLoader path supports RESOLVI size-factor training modes.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi( + RESOLVIGraph, + adata, + batch_key="batch", + size_factor_key="cell_area", + ) + model = RESOLVIGraph(adata, size_scaling=True) + _train_graph(model) + + _setup_resolvi(RESOLVIGraph, adata, size_factor_key="cell_area") + model = RESOLVIGraph(adata, size_scaling=False) + _train_graph(model) + + +@pytest.mark.optional +def test_resolvi_graph_save_load(adata, tmp_path): + """GraphDataLoader path preserves legacy save/load behavior.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata) + model = RESOLVIGraph(adata) + _train_graph(model) + hist_elbo = model.history_["elbo_train"] + latent = model.get_latent_representation() + assert latent.shape == (adata.n_obs, model.module.n_latent) + model.differential_expression(groupby="labels") + model.differential_expression(groupby="labels", weights="importance") + + save_path = str(tmp_path / "test_resolvi_graph") + model.save(save_path, save_anndata=True, overwrite=True) + model2 = model.load(save_path) + np.testing.assert_array_equal(model2.history_["elbo_train"], hist_elbo) + latent2 = model2.get_latent_representation() + assert np.allclose(latent, latent2) + model.load_query_data(reference_model=save_path, adata=adata) + + +@pytest.mark.optional +def test_resolvi_graph_downstream(adata, tmp_path): + """GraphDataLoader path covers legacy downstream RESOLVI APIs.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata, size_factor_key="cell_area") + model = RESOLVIGraph(adata) + _train_graph(model) + latent = model.get_latent_representation() + assert latent.shape == (adata.n_obs, model.module.n_latent) + _ = model.get_normalized_expression(n_samples=31, library_size=10000) + _ = model.get_normalized_expression_importance(n_samples=30, library_size=10000) + _ = model.get_normalized_expression_importance(n_samples=30, size_scaling=True) + model.differential_expression(groupby="labels") + model.differential_expression(groupby="labels", weights="importance") + model.differential_expression(groupby="labels", weights="importance", size_scaling=True) + model.sample_posterior( + model=model.module.model_residuals, + num_samples=30, + return_samples=False, + return_sites=None, + batch_size=1000, + ) + model.sample_posterior( + model=model.module.model_residuals, + num_samples=30, + return_samples=False, + batch_size=1000, + ) + + model.load_query_data(reference_model=model, adata=adata) + save_path = str(tmp_path / "test_resolvi_graph") + model.save(save_path, save_anndata=True, overwrite=True) + model_query = model.load_query_data(reference_model=save_path, adata=adata) + _train_graph(model_query) + + +def test_resolvi_graph_downstream_size_scaling(adata, tmp_path): + """GraphDataLoader path covers downstream APIs with size scaling enabled.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata, size_factor_key="cell_area") + model = RESOLVIGraph(adata, size_scaling=True) + _train_graph(model) + latent = model.get_latent_representation() + assert latent.shape == (adata.n_obs, model.module.n_latent) + _ = model.get_normalized_expression(n_samples=31, library_size=10000) + _ = model.get_normalized_expression_importance(n_samples=30, library_size=10000) + _ = model.get_normalized_expression_importance(n_samples=30, size_scaling=True) + model.differential_expression(groupby="labels") + model.differential_expression(groupby="labels", weights="importance") + model.differential_expression(groupby="labels", weights="importance", size_scaling=True) + model.sample_posterior( + model=model.module.model_residuals, + num_samples=30, + return_samples=False, + return_sites=None, + batch_size=1000, + ) + model.sample_posterior( + model=model.module.model_residuals, + num_samples=30, + return_samples=False, + batch_size=1000, + ) + + model.load_query_data(reference_model=model, adata=adata) + save_path = str(tmp_path / "test_resolvi_graph") + model.save(save_path, save_anndata=True, overwrite=True) + model_query = model.load_query_data(reference_model=save_path, adata=adata) + _train_graph(model_query) + + +@pytest.mark.optional +def test_resolvi_graph_semisupervised(adata): + """GraphDataLoader path supports semisupervised RESOLVI APIs.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata, labels_key="labels") + model = RESOLVIGraph(adata, semisupervised=True) + _train_graph(model) + model.differential_niche_abundance( + batch_size=30, + groupby="batch", + neighbor_key="index_neighbor", + ) + pred = model.predict(soft=True) + assert pred.shape == (adata.n_obs, model.summary_stats.n_labels - 1) + pred = model.predict(soft=False) + assert pred.shape == (adata.n_obs,) + + +def test_resolvi_graph_scarches(adata): + """GraphDataLoader path preserves legacy scArches query workflow.""" + RESOLVIGraph = _resolvi_graph_cls() + + adata.obs["hemisphere"] = ["right" if x > 0 else "left" for x in adata.obsm["X_spatial"][:, 0]] + ref_adata = adata[adata.obs["hemisphere"] == "left"].copy() + query_adata = adata[adata.obs["hemisphere"] == "right"].copy() + + _setup_resolvi(RESOLVIGraph, ref_adata, labels_key="labels") + model = RESOLVIGraph(ref_adata, semisupervised=True) + _train_graph(model) + + ref_adata.obsm["resolvi_celltypes"] = model.predict(ref_adata, num_samples=3, soft=True) + ref_adata.obs["resolvi_predicted"] = ref_adata.obsm["resolvi_celltypes"].idxmax(axis=1) + ref_adata.obsm["X_resolVI"] = model.get_latent_representation(ref_adata) + + query_adata.obs["predicted_celltype"] = "unknown" + query_adata.obs_names = [f"query_{i}" for i in query_adata.obs_names] + _add_ring_neighbors(query_adata) + + model.prepare_query_anndata(query_adata, reference_model=model) + query_resolvi = model.load_query_data(query_adata, reference_model=model) + + _train_graph(query_resolvi, max_epochs=1) + + query_adata.obs["resolvi_predicted"] = query_resolvi.predict( + query_adata, + num_samples=3, + soft=False, + ) + query_adata.obsm["X_resolVI"] = query_resolvi.get_latent_representation(query_adata) + + +@pytest.mark.parametrize("weights", ["importance", "uniform"]) +@pytest.mark.parametrize("n_samples", [1, 3]) +@pytest.mark.parametrize("downsample_counts", [True, False]) +def test_resolvi_graph_differential_expression( + adata, + weights: str, + n_samples: int, + downsample_counts: bool, +): + """GraphDataLoader path supports RESOLVI differential expression settings.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata) + model = RESOLVIGraph(adata, downsample_counts=downsample_counts) + _train_graph(model, max_epochs=1) + model.differential_expression(groupby="labels", weights=weights, n_samples=n_samples) + + +@pytest.mark.benchmark +def test_resolvi_dataloader_speed_comparison(adata): + """Side-by-side wall-clock comparison: AnnDataLoader vs GraphDataLoader.""" + RESOLVILegacy = _resolvi_legacy_cls() + RESOLVIGraph = _resolvi_graph_cls() + + n_epochs = 5 + + _setup_resolvi(RESOLVILegacy, adata) + model_ann = RESOLVILegacy(adata) + t0 = time.perf_counter() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model_ann.train(max_epochs=n_epochs) + t_ann = time.perf_counter() - t0 + + _setup_resolvi(RESOLVIGraph, adata) + model_graph = RESOLVIGraph(adata) + t0 = time.perf_counter() + _train_graph(model_graph, max_epochs=n_epochs) + t_graph = time.perf_counter() - t0 + + print(f"\nAnnDataLoader: {t_ann:.2f}s total ({t_ann / n_epochs:.3f}s/epoch)") + print(f"GraphDataLoader: {t_graph:.2f}s total ({t_graph / n_epochs:.3f}s/epoch)") + print(f"Ratio (graph/ann): {t_graph / t_ann:.2f}x") + + assert t_graph / t_ann < 3.0, ( + f"GraphDataLoader is {t_graph / t_ann:.1f}x slower than AnnDataLoader" + ) + + +@pytest.mark.benchmark +def test_resolvi_elbo_comparable_between_paths(adata): + """Both paths should reach similar final ELBO after training.""" + RESOLVILegacy = _resolvi_legacy_cls() + RESOLVIGraph = _resolvi_graph_cls() + + n_epochs = 10 + + _setup_resolvi(RESOLVILegacy, adata) + model_ann = RESOLVILegacy(adata) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model_ann.train(max_epochs=n_epochs) + elbo_ann = model_ann.history_["elbo_train"].iloc[-1].values[0] + + _setup_resolvi(RESOLVIGraph, adata) + model_graph = RESOLVIGraph(adata) + _train_graph(model_graph, max_epochs=n_epochs) + elbo_graph = model_graph.history_["elbo_train"].iloc[-1].values[0] + + print(f"\nFinal ELBO - AnnDataLoader: {elbo_ann:.2f} GraphDataLoader: {elbo_graph:.2f}") + + assert np.isfinite(elbo_graph), "GraphDataLoader ELBO is not finite" + assert abs(elbo_graph - elbo_ann) / (abs(elbo_ann) + 1e-8) < 0.5, ( + f"ELBO diverged: ann={elbo_ann:.2f} graph={elbo_graph:.2f}" + ) + + +@pytest.mark.benchmark +def test_resolvi_graph_elbo_decreases(adata): + """ELBO must decrease over training with GraphDataLoader.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata) + model = RESOLVIGraph(adata) + _train_graph(model, max_epochs=10) + + history = model.history_["elbo_train"] + assert history.iloc[-1].values[0] < history.iloc[0].values[0], ( + "ELBO did not decrease with GraphDataLoader" + ) + + +@pytest.mark.benchmark +def test_resolvi_latent_shape_graph_path(adata): + """get_latent_representation() must return (n_obs, n_latent) with GraphDataLoader.""" + RESOLVIGraph = _resolvi_graph_cls() + + _setup_resolvi(RESOLVIGraph, adata) + model = RESOLVIGraph(adata) + _train_graph(model, max_epochs=3) + latent = model.get_latent_representation() + assert latent.shape == (adata.n_obs, model.module.n_latent)