Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ Writing formats that cannot represent all aspects of {class}`AnnData` objects.
AnnData.write_loom
```

(utilities-api)=

## Utilities

Helper functions used internationally or for reshaping and aligng `AnnData` objects. Can be useful for cusotm workflows or edge cases.

```{eval-rst}
.. autosummary::
:toctree: generated/

utils.adapt_vars_like

```

(experimental-api)=

## Experimental API
Expand Down
1 change: 1 addition & 0 deletions src/anndata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ def __getattr__(attr_name: str) -> Any:
"settings",
"types",
"typing",
"utils",
]
2 changes: 1 addition & 1 deletion src/anndata/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .compat import Index as _Index

if TYPE_CHECKING:
from typing import TypeAlias
from typing import AxisStorable, TypeAlias


__all__ = ["AxisStorable", "Index", "RWAble"]
Expand Down
79 changes: 78 additions & 1 deletion src/anndata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal
from typing import Any, AxisStorable, Literal

logger = get_logger(__name__)

Expand Down Expand Up @@ -449,3 +449,80 @@ def module_get_attr_redirect(
return getattr(mod, new_path)
msg = f"module {full_old_module_path} has no attribute {attr_name!r}"
raise AttributeError(msg)


def adapt_vars_like(
source: AnnData,
target: AnnData,
fill_value: float = 0.0,
) -> AnnData:
"""
Adapt the `.var` structure of `target` to match that of `source`.

This function makes sure that the `target` AnnData object has the same set
of genes (`.var_names`) as the `source` AnnData object. It fills in the
any missing genes in the `target` object with a specified `fill_value`.

Parameters
----------
source
Reference AnnData object whose genes (.var) define the desired structure.
target
AnnData object to be adapted to match the source's gene structure.
fill_value
Value used to fill in missing genes. Defaults to 0.0.

Returns
-------
AnnData
A new AnnData object with the genes matching the source's structure and data from
`target`, with missing values filled in with `fill_value`.

"""
# importing here to avoid circular import issues
from ._core.anndata import AnnData
from ._core.merge import Reindexer

new_var = source.var.copy()
reindexer = Reindexer(target.var.index, new_var.index)
if target.X is None:
new_X = None
else:
new_X = reindexer(target.X, axis=1, fill_value=fill_value)

new_layers = {
k: reindexer(v, axis=1, fill_value=fill_value) for k, v in target.layers.items()
}

new_varm: AxisStorable = {
k: reindexer(v, axis=0, fill_value=fill_value) for k, v in target.varm.items()
}

new_varp: AxisStorable = {
k: reindexer(
reindexer(v, axis=0, fill_value=fill_value), axis=1, fill_value=fill_value
)
for k, v in target.varp.items()
}

new_obsp = {k: v.copy() for k, v in target.obsp.items()}

new_adata = AnnData(
X=new_X,
obs=target.obs.copy(),
var=new_var,
varm=new_varm,
varp=new_varp,
obsp=new_obsp,
layers=new_layers,
)

if target.raw is not None:
new_raw_X = reindexer(target.raw.X, axis=1, fill_value=fill_value)
new_adata.raw = AnnData(
X=new_raw_X,
var=source.var.copy(),
obs=target.obs.copy(),
)

return new_adata
127 changes: 126 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from itertools import repeat

import numpy as np
import pandas as pd
import pytest
from scipy import sparse

import anndata as ad
from anndata.tests.helpers import gen_typed_df
from anndata.utils import make_index_unique
from anndata.utils import adapt_vars_like, make_index_unique


def test_make_index_unique() -> None:
Expand Down Expand Up @@ -57,3 +58,127 @@ def test_adata_unique_indices():

pd.testing.assert_index_equal(v.obsm["df"].index, v.obs_names)
pd.testing.assert_index_equal(v.varm["df"].index, v.var_names)


@pytest.mark.parametrize(
("source", "target", "expected_X"),
[
pytest.param(
ad.AnnData(X=np.ones((1, 3)), var=pd.DataFrame(index=["a", "b", "c"])),
ad.AnnData(
X=np.array([[1, 2, 3]]), var=pd.DataFrame(index=["a", "b", "c"])
),
np.array([[1, 2, 3]]),
id="exact_match",
),
pytest.param(
ad.AnnData(X=np.ones((1, 3)), var=pd.DataFrame(index=["a", "b", "c"])),
ad.AnnData(
X=np.array([[3, 2, 1]]), var=pd.DataFrame(index=["c", "b", "a"])
),
np.array([[1, 2, 3]]),
id="different_order",
),
],
)
def test_adapt_vars(source, target, expected_X):
output = adapt_vars_like(source, target)
np.testing.assert_array_equal(output.X, expected_X)
assert list(output.var_names) == list(source.var_names)


@pytest.mark.parametrize(
("source", "target", "fill_value", "expected_X"),
[
pytest.param(
ad.AnnData(X=np.ones((1, 2)), var=pd.DataFrame(index=["g1", "g2"])),
ad.AnnData(X=np.array([[7, 8]]), var=pd.DataFrame(index=["g3", "g4"])),
0.5,
np.array([[0.5, 0.5]]),
id="no_shared_genes",
),
pytest.param(
ad.AnnData(X=np.ones((1, 3)), var=pd.DataFrame(index=["g1", "g2", "g3"])),
ad.AnnData(X=np.array([[1, 3]]), var=pd.DataFrame(index=["g1", "g3"])),
-1,
np.array([[1, -1, 3]]),
id="missing_genes",
),
],
)
def test_adapt_vars_with_fill_value(source, target, fill_value, expected_X):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's merge this test and test_adapt_vars and have a fill value of None for the test_adapt_vars ones in the param

output = adapt_vars_like(source, target, fill_value=fill_value)
np.testing.assert_array_equal(output.X, expected_X)
assert list(output.var_names) == list(source.var_names)


def test_adapt_vars_target_X_none():
source = ad.AnnData(
X=np.ones((2, 2)),
var=pd.DataFrame(index=["g1", "g2"]),
)
target = ad.AnnData(
X=None,
var=pd.DataFrame(index=["g2", "g3"]),
obs=pd.DataFrame(index=["cell1", "cell2"]),
)
output = adapt_vars_like(source, target, fill_value=-1)
assert output.X is None
assert list(output.var_names) == list(source.var_names)


def test_adapt_vars_all_objects():
source = ad.AnnData(
X=np.ones((2, 3)),
var=gen_typed_df(3, index=pd.Index(["a", "b", "c"])),
)

target = ad.AnnData(
X=np.array([[1, 3], [2, 4]]),
var=gen_typed_df(2, index=pd.Index(["a", "c"])),
obs=pd.DataFrame(index=["cell1", "cell2"]),
varm={"varm_key": np.array([[10, 11], [30, 31]])},
varp={"varp_key": np.array([[1, 2], [3, 4]])},
obsp={"obsp_key": np.array([[5, 6], [7, 8]])},
layers={"layer1": np.array([[1000, 3000], [1001, 3001]])},
)

output = adapt_vars_like(source, target, fill_value=-1)

expected_X = np.array(
[
[1, -1, 3],
[2, -1, 4],
]
)
np.testing.assert_array_equal(output.X, expected_X)
assert list(output.var_names) == ["a", "b", "c"]

expected_layer = np.array(
[
[1000, -1, 3000],
[1001, -1, 3001],
]
)
np.testing.assert_array_equal(output.layers["layer1"], expected_layer)

expected_varm = np.array(
[
[10, 11],
[-1, -1],
[30, 31],
]
)
np.testing.assert_array_equal(output.varm["varm_key"], expected_varm)

expected_varp = np.array(
[
[1, -1, 2],
[-1, -1, -1],
[3, -1, 4],
]
)
np.testing.assert_array_equal(output.varp["varp_key"], expected_varp)

expected_obsp = np.array([[5, 6], [7, 8]])
np.testing.assert_array_equal(output.obsp["obsp_key"], expected_obsp)
Loading