Skip to content

Commit a5e5761

Browse files
authored
feat: add layer param to filter_rank_genes_groups (#3999)
1 parent 6283892 commit a5e5761

3 files changed

Lines changed: 21 additions & 6 deletions

File tree

docs/release-notes/3999.feat.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `layer` parameter to {func}`~scanpy.tl.filter_rank_genes_groups` {smaller}`P Angerer`

src/scanpy/tools/_rank_genes_groups.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
get_literal_vals,
1919
raise_not_implemented_error_if_backed_type,
2020
)
21-
from ..get import _check_mask
21+
from ..get import _check_mask, _get_obs_rep
2222

2323
if TYPE_CHECKING:
2424
from collections.abc import Generator, Iterable
@@ -763,6 +763,7 @@ def filter_rank_genes_groups( # noqa: PLR0912
763763
*,
764764
key: str | None = None,
765765
groupby: str | None = None,
766+
layer: str | None = None,
766767
use_raw: bool | None = None,
767768
key_added: str = "rank_genes_groups_filtered",
768769
min_in_group_fraction: float = 0.25,
@@ -789,6 +790,7 @@ def filter_rank_genes_groups( # noqa: PLR0912
789790
adata
790791
key
791792
groupby
793+
layer
792794
use_raw
793795
key_added
794796
min_in_group_fraction
@@ -799,8 +801,7 @@ def filter_rank_genes_groups( # noqa: PLR0912
799801
800802
Returns
801803
-------
802-
Same output as :func:`scanpy.tl.rank_genes_groups` but with filtered genes names set to
803-
`nan`
804+
Same output as :func:`scanpy.tl.rank_genes_groups` but with filtered genes names set to `nan`.
804805
805806
Examples
806807
--------
@@ -821,7 +822,9 @@ def filter_rank_genes_groups( # noqa: PLR0912
821822
groupby = adata.uns[key]["params"]["groupby"]
822823

823824
if use_raw is None:
824-
use_raw = adata.uns[key]["params"]["use_raw"]
825+
use_raw = adata.uns[key]["params"]["use_raw"] if layer is None else False
826+
827+
x = _get_obs_rep(adata, use_raw=use_raw, layer=layer)
825828

826829
same_params = (
827830
adata.uns[key]["params"]["groupby"] == groupby
@@ -872,7 +875,8 @@ def filter_rank_genes_groups( # noqa: PLR0912
872875
var_names = gene_names[cluster].values
873876

874877
if not use_logfolds or not use_fraction:
875-
sub_x = adata.raw[:, var_names].X if use_raw else adata[:, var_names].X
878+
var_idx = (adata.raw if use_raw else adata).var_names.get_indexer(var_names)
879+
sub_x = x[:, var_idx]
876880
in_group = (adata.obs[groupby] == cluster).to_numpy()
877881
x_in = sub_x[in_group]
878882
x_out = sub_x[~in_group]

tests/test_filter_rank_genes_groups.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,30 @@
6262
pytest.param("rest", True, True, id="rest-pts-abs"),
6363
],
6464
)
65-
def test_filter_rank_genes_groups(reference, pts, abs):
65+
@pytest.mark.parametrize("layer", [None, "layer"], ids=["raw", "layer"])
66+
def test_filter_rank_genes_groups(
67+
*, reference: str, pts: bool, abs: bool, layer: str | None
68+
) -> None:
6669
adata = pbmc68k_reduced()
70+
if layer is not None:
71+
adata.layers[layer] = adata.raw.X
72+
del adata.X
73+
del adata.raw
6774

6875
rank_genes_groups(
6976
adata,
7077
"bulk_labels",
7178
reference=reference,
7279
pts=pts,
7380
method="wilcoxon",
81+
layer=layer,
7482
rankby_abs=abs,
7583
n_genes=5,
7684
)
7785
if abs:
7886
filter_rank_genes_groups(
7987
adata,
88+
layer=layer,
8089
compare_abs=True,
8190
min_in_group_fraction=-1,
8291
max_out_group_fraction=1,
@@ -85,6 +94,7 @@ def test_filter_rank_genes_groups(reference, pts, abs):
8594
else:
8695
filter_rank_genes_groups(
8796
adata,
97+
layer=layer,
8898
min_in_group_fraction=0.25,
8999
min_fold_change=1,
90100
max_out_group_fraction=0.5,

0 commit comments

Comments
 (0)