diff --git a/docs/api.md b/docs/api.md index 2bb76940d..1de1e81c2 100644 --- a/docs/api.md +++ b/docs/api.md @@ -118,6 +118,20 @@ Writing formats that cannot represent all aspects of {class}`AnnData` objects. AnnData.write_loom ``` +(utilities-api)= + +## Utilities + +Helper functions used internationally or for reshaping and aligng `AnnData` objects. Can be useful for cusotm workflows or edge cases. + +```{eval-rst} +.. autosummary:: + :toctree: generated/ + + utils.adapt_vars_like + +``` + (experimental-api)= ## Experimental API diff --git a/src/anndata/__init__.py b/src/anndata/__init__.py index 5925837f6..a77b00183 100644 --- a/src/anndata/__init__.py +++ b/src/anndata/__init__.py @@ -62,4 +62,5 @@ def __getattr__(attr_name: str) -> Any: "settings", "types", "typing", + "utils", ] diff --git a/src/anndata/typing.py b/src/anndata/typing.py index 25e279248..6264fc67f 100644 --- a/src/anndata/typing.py +++ b/src/anndata/typing.py @@ -23,7 +23,7 @@ from .compat import Index as _Index if TYPE_CHECKING: - from typing import TypeAlias + from typing import AxisStorable, TypeAlias __all__ = ["AxisStorable", "Index", "RWAble"] diff --git a/src/anndata/utils.py b/src/anndata/utils.py index 55e6292b4..54442dda7 100644 --- a/src/anndata/utils.py +++ b/src/anndata/utils.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Mapping, Sequence - from typing import Any, Literal + from typing import Any, AxisStorable, Literal logger = get_logger(__name__) @@ -449,3 +449,80 @@ def module_get_attr_redirect( return getattr(mod, new_path) msg = f"module {full_old_module_path} has no attribute {attr_name!r}" raise AttributeError(msg) + + +def adapt_vars_like( + source: AnnData, + target: AnnData, + fill_value: float = 0.0, +) -> AnnData: + """ + Adapt the `.var` structure of `target` to match that of `source`. + + This function makes sure that the `target` AnnData object has the same set + of genes (`.var_names`) as the `source` AnnData object. It fills in the + any missing genes in the `target` object with a specified `fill_value`. + + Parameters + ---------- + source + Reference AnnData object whose genes (.var) define the desired structure. + target + AnnData object to be adapted to match the source's gene structure. + fill_value + Value used to fill in missing genes. Defaults to 0.0. + + Returns + ------- + AnnData + A new AnnData object with the genes matching the source's structure and data from + `target`, with missing values filled in with `fill_value`. + + """ + # importing here to avoid circular import issues + from ._core.anndata import AnnData + from ._core.merge import Reindexer + + new_var = source.var.copy() + reindexer = Reindexer(target.var.index, new_var.index) + if target.X is None: + new_X = None + else: + new_X = reindexer(target.X, axis=1, fill_value=fill_value) + + new_layers = { + k: reindexer(v, axis=1, fill_value=fill_value) for k, v in target.layers.items() + } + + new_varm: AxisStorable = { + k: reindexer(v, axis=0, fill_value=fill_value) for k, v in target.varm.items() + } + + new_varp: AxisStorable = { + k: reindexer( + reindexer(v, axis=0, fill_value=fill_value), axis=1, fill_value=fill_value + ) + for k, v in target.varp.items() + } + + new_obsp = {k: v.copy() for k, v in target.obsp.items()} + + new_adata = AnnData( + X=new_X, + obs=target.obs.copy(), + var=new_var, + varm=new_varm, + varp=new_varp, + obsp=new_obsp, + layers=new_layers, + ) + + if target.raw is not None: + new_raw_X = reindexer(target.raw.X, axis=1, fill_value=fill_value) + new_adata.raw = AnnData( + X=new_raw_X, + var=source.var.copy(), + obs=target.obs.copy(), + ) + + return new_adata diff --git a/tests/test_utils.py b/tests/test_utils.py index 35cd246be..632accb8a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,13 +2,14 @@ from itertools import repeat +import numpy as np import pandas as pd import pytest from scipy import sparse import anndata as ad from anndata.tests.helpers import gen_typed_df -from anndata.utils import make_index_unique +from anndata.utils import adapt_vars_like, make_index_unique def test_make_index_unique() -> None: @@ -57,3 +58,127 @@ def test_adata_unique_indices(): pd.testing.assert_index_equal(v.obsm["df"].index, v.obs_names) pd.testing.assert_index_equal(v.varm["df"].index, v.var_names) + + +@pytest.mark.parametrize( + ("source", "target", "expected_X"), + [ + pytest.param( + ad.AnnData(X=np.ones((1, 3)), var=pd.DataFrame(index=["a", "b", "c"])), + ad.AnnData( + X=np.array([[1, 2, 3]]), var=pd.DataFrame(index=["a", "b", "c"]) + ), + np.array([[1, 2, 3]]), + id="exact_match", + ), + pytest.param( + ad.AnnData(X=np.ones((1, 3)), var=pd.DataFrame(index=["a", "b", "c"])), + ad.AnnData( + X=np.array([[3, 2, 1]]), var=pd.DataFrame(index=["c", "b", "a"]) + ), + np.array([[1, 2, 3]]), + id="different_order", + ), + ], +) +def test_adapt_vars(source, target, expected_X): + output = adapt_vars_like(source, target) + np.testing.assert_array_equal(output.X, expected_X) + assert list(output.var_names) == list(source.var_names) + + +@pytest.mark.parametrize( + ("source", "target", "fill_value", "expected_X"), + [ + pytest.param( + ad.AnnData(X=np.ones((1, 2)), var=pd.DataFrame(index=["g1", "g2"])), + ad.AnnData(X=np.array([[7, 8]]), var=pd.DataFrame(index=["g3", "g4"])), + 0.5, + np.array([[0.5, 0.5]]), + id="no_shared_genes", + ), + pytest.param( + ad.AnnData(X=np.ones((1, 3)), var=pd.DataFrame(index=["g1", "g2", "g3"])), + ad.AnnData(X=np.array([[1, 3]]), var=pd.DataFrame(index=["g1", "g3"])), + -1, + np.array([[1, -1, 3]]), + id="missing_genes", + ), + ], +) +def test_adapt_vars_with_fill_value(source, target, fill_value, expected_X): + output = adapt_vars_like(source, target, fill_value=fill_value) + np.testing.assert_array_equal(output.X, expected_X) + assert list(output.var_names) == list(source.var_names) + + +def test_adapt_vars_target_X_none(): + source = ad.AnnData( + X=np.ones((2, 2)), + var=pd.DataFrame(index=["g1", "g2"]), + ) + target = ad.AnnData( + X=None, + var=pd.DataFrame(index=["g2", "g3"]), + obs=pd.DataFrame(index=["cell1", "cell2"]), + ) + output = adapt_vars_like(source, target, fill_value=-1) + assert output.X is None + assert list(output.var_names) == list(source.var_names) + + +def test_adapt_vars_all_objects(): + source = ad.AnnData( + X=np.ones((2, 3)), + var=gen_typed_df(3, index=pd.Index(["a", "b", "c"])), + ) + + target = ad.AnnData( + X=np.array([[1, 3], [2, 4]]), + var=gen_typed_df(2, index=pd.Index(["a", "c"])), + obs=pd.DataFrame(index=["cell1", "cell2"]), + varm={"varm_key": np.array([[10, 11], [30, 31]])}, + varp={"varp_key": np.array([[1, 2], [3, 4]])}, + obsp={"obsp_key": np.array([[5, 6], [7, 8]])}, + layers={"layer1": np.array([[1000, 3000], [1001, 3001]])}, + ) + + output = adapt_vars_like(source, target, fill_value=-1) + + expected_X = np.array( + [ + [1, -1, 3], + [2, -1, 4], + ] + ) + np.testing.assert_array_equal(output.X, expected_X) + assert list(output.var_names) == ["a", "b", "c"] + + expected_layer = np.array( + [ + [1000, -1, 3000], + [1001, -1, 3001], + ] + ) + np.testing.assert_array_equal(output.layers["layer1"], expected_layer) + + expected_varm = np.array( + [ + [10, 11], + [-1, -1], + [30, 31], + ] + ) + np.testing.assert_array_equal(output.varm["varm_key"], expected_varm) + + expected_varp = np.array( + [ + [1, -1, 2], + [-1, -1, -1], + [3, -1, 4], + ] + ) + np.testing.assert_array_equal(output.varp["varp_key"], expected_varp) + + expected_obsp = np.array([[5, 6], [7, 8]]) + np.testing.assert_array_equal(output.obsp["obsp_key"], expected_obsp)