Skip to content

Commit f426717

Browse files
j-bacpre-commit-ci[bot]grst
authored
Fix bug with the layer parameter (#168)
* fix layer parameter * fix layer parameter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make layer an optional parameter to prevent failing tests * Fix tests --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Gregor Sturm <mail@gregor-sturm.de>
1 parent 585c9ea commit f426717

2 files changed

Lines changed: 36 additions & 9 deletions

File tree

src/infercnvpy/tl/_infercnv.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def infercnv(
108108
var_mask = var_mask | adata.var["chromosome"].isin(exclude_chromosomes)
109109

110110
tmp_adata = adata[:, ~var_mask]
111-
reference = _get_reference(adata, reference_key, reference_cat, reference)[:, ~var_mask]
111+
reference = _get_reference(adata, reference_key, reference_cat, reference, layer)[:, ~var_mask]
112112

113113
expr = tmp_adata.X if layer is None else tmp_adata.layers[layer]
114114

@@ -358,6 +358,7 @@ def _get_reference(
358358
reference_key: str | None,
359359
reference_cat: None | str | Sequence[str],
360360
reference: np.ndarray | None,
361+
layer: str | None,
361362
) -> np.ndarray:
362363
"""Parameter validation extraction of reference gene expression.
363364
@@ -367,13 +368,18 @@ def _get_reference(
367368
Returns a 2D array with reference categories in rows, cells in columns.
368369
If there's just one category, it's still a 2D array.
369370
"""
371+
if layer is not None:
372+
X = adata.layers[layer]
373+
else:
374+
X = adata.X
375+
370376
if reference is None:
371377
if reference_key is None or reference_cat is None:
372378
logging.warning(
373379
"Using mean of all cells as reference. For better results, "
374380
"provide either `reference`, or both `reference_key` and `reference_cat`. "
375381
) # type: ignore
376-
reference = np.mean(adata.X, axis=0)
382+
reference = np.mean(X, axis=0)
377383

378384
else:
379385
obs_col = adata.obs[reference_key]
@@ -388,7 +394,7 @@ def _get_reference(
388394
f"{reference_cat[~reference_cat_in_obs]}"
389395
)
390396

391-
reference = np.vstack([np.mean(adata.X[obs_col.values == cat, :], axis=0) for cat in reference_cat])
397+
reference = np.vstack([np.mean(X[obs_col.values == cat, :], axis=0) for cat in reference_cat])
392398

393399
if reference.ndim == 1:
394400
reference = reference[np.newaxis, :]

tests/test_tools.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
def test_get_reference_key_and_cat(adata_mock):
1212
"""Test that reference is correctly calculated given a reference key and category"""
13-
actual = _get_reference(adata_mock, "cat", ["foo", "baz"], None)
13+
actual = _get_reference(adata_mock, "cat", ["foo", "baz"], None, layer=None)
1414
npt.assert_almost_equal(
1515
actual,
1616
np.array(
@@ -24,19 +24,19 @@ def test_get_reference_key_and_cat(adata_mock):
2424

2525
def test_get_reference_no_reference(adata_mock):
2626
"""If no reference is specified, the mean of the entire adata object is taken"""
27-
actual = _get_reference(adata_mock, None, None, None)
27+
actual = _get_reference(adata_mock, None, None, None, layer=None)
2828
npt.assert_almost_equal(actual, np.array([[4.8, 4.2, 4.4, 5]]), decimal=5)
2929

3030

3131
def test_get_reference_given_reference(adata_mock):
3232
"""Predefined reference takes precendence over reference_key and reference_cat"""
3333
reference = np.array([1, 2, 3, 4])
34-
actual = _get_reference(adata_mock, "foo", "bar", reference)
34+
actual = _get_reference(adata_mock, "foo", "bar", reference, layer=None)
3535
npt.assert_equal(reference, actual[0, :])
3636

3737
with pytest.raises(ValueError):
3838
reference = np.array([1, 2, 3])
39-
actual = _get_reference(adata_mock, "foo", "bar", reference)
39+
actual = _get_reference(adata_mock, "foo", "bar", reference, layer=None)
4040

4141

4242
@pytest.mark.parametrize(
@@ -141,7 +141,7 @@ def test_calculate_gene_averages():
141141

142142

143143
def test_infercnv_chunk_with_gene_values(adata_full_mock, gene_res_actual, x_res_actual):
144-
reference = _get_reference(adata_full_mock, reference_key=None, reference_cat=None, reference=None)
144+
reference = _get_reference(adata_full_mock, reference_key=None, reference_cat=None, reference=None, layer=None)
145145
var = adata_full_mock.var.loc[:, ["chromosome", "start", "end"]]
146146
tmp_x = adata_full_mock.X
147147

@@ -156,7 +156,7 @@ def test_infercnv_chunk_with_gene_values(adata_full_mock, gene_res_actual, x_res
156156

157157

158158
def test_infercnv_chunk_default(adata_full_mock, x_res_actual):
159-
reference = _get_reference(adata_full_mock, reference_key=None, reference_cat=None, reference=None)
159+
reference = _get_reference(adata_full_mock, reference_key=None, reference_cat=None, reference=None, layer=None)
160160
var = adata_full_mock.var.loc[:, ["chromosome", "start", "end"]]
161161
tmp_x = adata_full_mock.X
162162

@@ -218,6 +218,27 @@ def test_workflow(adata_oligodendroma):
218218
cnv.pl.chromosome_heatmap_summary(adata_oligodendroma, show=False)
219219

220220

221+
def test_layer_parameter(adata_oligodendroma):
222+
adata_oligodendroglioma = cnv.datasets.oligodendroglioma()
223+
224+
# create lognorm layer but leave X unchanged
225+
adata_oligodendroglioma.layers["LogNormalize"] = adata_oligodendroglioma.X.copy()
226+
sc.pp.normalize_total(adata_oligodendroglioma, layer="LogNormalize", target_sum=1e4)
227+
sc.pp.log1p(adata_oligodendroglioma, layer="LogNormalize")
228+
229+
# copy to test if results change with layer option
230+
adata_oligodendroglioma2 = adata_oligodendroglioma.copy()
231+
adata_oligodendroglioma2.X = adata_oligodendroglioma.layers["LogNormalize"]
232+
233+
cnv.tl.infercnv(adata_oligodendroglioma, layer="LogNormalize")
234+
cnv.tl.infercnv(adata_oligodendroglioma2, layer=None)
235+
236+
X_cnv = adata_oligodendroglioma.obsm["X_cnv"].toarray()
237+
X_cnv2 = adata_oligodendroglioma2.obsm["X_cnv"].toarray()
238+
239+
assert np.all(X_cnv == X_cnv2), "Different results found with infercnv layer parameter"
240+
241+
221242
def test_calculate_gene_values_speed(benchmark, adata_oligodendroma):
222243
# Benchmark with calculate_gene_values=True
223244
benchmark(

0 commit comments

Comments
 (0)