Skip to content

Commit 02ecd49

Browse files
committed
hvg works
1 parent bca40a8 commit 02ecd49

3 files changed

Lines changed: 26 additions & 11 deletions

File tree

src/scanpy/preprocessing/_highly_variable_genes.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -411,7 +411,10 @@ def _subset_genes(
411411
def _nth_highest(x: NDArray[np.float64] | DaskArray, n: int) -> float | DaskArray:
412412
x = x[~np.isnan(x)]
413413
if n > x.size:
414-
msg = "`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions."
414+
msg = (
415+
f"`n_top_genes` (={n}) > number of normalized dispersions (={x.size}), "
416+
"returning all genes with normalized dispersions."
417+
)
415418
# 5: caller -> 4: `highly_variable_genes` -> 3: `_…_single_batch` -> 2: `_subset_genes` -> 1: here
416419
warnings.warn(msg, UserWarning, stacklevel=5)
417420
n = x.size

src/testing/scanpy/_helpers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ def check_rep_results(func, X, *, fields: Iterable[str] = ("layer", "obsm"), **k
101101

102102

103103
def _check_check_values_warnings(
104-
function, adata, expected_warning, kwargs=MappingProxyType({})
104+
function, adata: AnnData, expected_warning: str, kwargs=MappingProxyType({})
105105
):
106106
"""Run `function` on `adata` with provided arguments `kwargs` twice.
107107
@@ -110,12 +110,14 @@ def _check_check_values_warnings(
110110
"""
111111
# expecting 0 no-int warnings
112112
with warnings.catch_warnings(record=True) as record:
113+
warnings.filterwarnings("always")
113114
function(adata.copy(), **kwargs, check_values=False)
114115
warning_msgs = [w.message.args[0] for w in record]
115116
assert expected_warning not in warning_msgs
116117

117118
# expecting 1 no-int warning
118119
with warnings.catch_warnings(record=True) as record:
120+
warnings.filterwarnings("always")
119121
function(adata.copy(), **kwargs, check_values=True)
120122
warning_msgs = [w.message.args[0] for w in record]
121123
assert expected_warning in warning_msgs

tests/test_highly_variable_genes.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import itertools
4+
from contextlib import nullcontext
45
from pathlib import Path
56
from string import ascii_letters
67
from typing import TYPE_CHECKING
@@ -391,11 +392,12 @@ def test_compare_to_upstream(
391392
sc.pp.log1p(pbmc)
392393
sc.pp.highly_variable_genes(pbmc, flavor=flavor, **params, inplace=True)
393394
elif func == "fgd":
394-
sc.pp.filter_genes_dispersion(
395-
pbmc, flavor=flavor, **params, log=True, subset=False
396-
)
395+
with pytest.warns(FutureWarning, match=r"sc\.pp\.highly_variable_genes"):
396+
sc.pp.filter_genes_dispersion(
397+
pbmc, flavor=flavor, **params, log=True, subset=False
398+
)
397399
else:
398-
raise AssertionError()
400+
pytest.fail(f"Unknown func {func}")
399401

400402
np.testing.assert_array_equal(
401403
hvg_info["highly_variable"], pbmc.var["highly_variable"]
@@ -505,7 +507,7 @@ def test_seurat_v3_warning():
505507

506508
def test_batches():
507509
adata = pbmc68k_reduced()
508-
adata[:100, :100].X = np.zeros((100, 100))
510+
adata.X[:100, :100] = np.zeros((100, 100))
509511

510512
adata.obs["batch"] = ["0" if i < 100 else "1" for i in range(adata.n_obs)]
511513
adata_1 = adata[adata.obs["batch"] == "0"].copy()
@@ -558,6 +560,7 @@ def test_batches():
558560
assert np.all(np.isin(colnames, hvg1.columns))
559561

560562

563+
@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
561564
def test_degenerate_batches():
562565
adata = AnnData(
563566
X=np.random.randn(10, 100),
@@ -692,10 +695,17 @@ def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask):
692695
adata_dask = adata.copy()
693696
adata_dask.X = to_dask(adata_dask.X)
694697

695-
output_mem, output_dask = (
696-
sc.pp.highly_variable_genes(ad, flavor=flavor, n_top_genes=15, inplace=False)
697-
for ad in [adata, adata_dask]
698-
)
698+
with (
699+
pytest.warns(UserWarning, match="n_top_genes.*normalized dispersions")
700+
if flavor == "cell_ranger"
701+
else nullcontext()
702+
):
703+
output_mem, output_dask = (
704+
sc.pp.highly_variable_genes(
705+
ad, flavor=flavor, n_top_genes=15, inplace=False
706+
)
707+
for ad in [adata, adata_dask]
708+
)
699709

700710
assert isinstance(output_mem, pd.DataFrame)
701711
assert isinstance(output_dask, pd.DataFrame)

0 commit comments

Comments
 (0)