Skip to content

Commit f5e8084

Browse files
committed
update the fix and the test case
1 parent 0af1761 commit f5e8084

2 files changed

Lines changed: 70 additions & 13 deletions

File tree

src/squidpy/gr/_ligrec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,12 +731,12 @@ def extractor(res: Sequence[TempResult]) -> TempResult:
731731

732732
return TempResult(means=means, pvalues=pvalues)
733733

734-
data = data.astype(np.float64, copy=False)
735734
groups = data.groupby("clusters", observed=True)
736735
clustering = np.array(data["clusters"].values, dtype=np.int32)
737736

738737
mean = groups.mean().values.T # (n_genes, n_clusters)
739-
mask = groups.apply(lambda c: ((c > 0).sum() / len(c)) >= threshold).values.T # (n_genes, n_clusters)
738+
mask = groups.apply(lambda c: ((c > 0).astype(int).sum() / len(c)) >= threshold).values.T # (n_genes, n_clusters)
739+
740740
# (n_cells, n_genes)
741741
data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C")
742742
# all 3 should be C contiguous

tests/graph/test_ligrec.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pandas.testing import assert_frame_equal
1515
from scanpy import settings as s
1616
from scanpy.datasets import blobs
17+
from scipy.sparse import csc_matrix
1718

1819
from squidpy._constants._pkg_constants import Key
1920
from squidpy.gr import ligrec
@@ -462,16 +463,72 @@ def test_none_source_target(self, adata: AnnData):
462463
assert isinstance(pt.interactions, pd.DataFrame)
463464
assert len(pt.interactions) == 1
464465

465-
def test_pvalues_nans(self, adata_hne: AnnData):
466+
def test_ligrec_nan_counts(self):
467+
"""
468+
For the test case with 2 clusters (A, B) and 3 gene pairs (Gene1→Gene2, Gene2→Gene3, Gene3→Gene1):
469+
470+
Expression fractions per cluster (computed as fraction of cells with value > 0):
471+
Cluster A:
472+
- Gene1: 1/3 = 0.33 (only A1 has value > 0)
473+
- Gene2: 3/3 = 1.00 (A1, A2 and A3 have value > 0; A1 is 0.1)
474+
- Gene3: 0/3 = 0.00 (none have value > 0)
475+
476+
Cluster B:
477+
- Gene1: 1/3 = 0.33 (only B1 has value > 0)
478+
- Gene2: 0/3 = 0.00 (none have value > 0)
479+
- Gene3: 3/3 = 1.00 (all have value > 0)
480+
481+
The mask is computed for each gene in each cluster as:
482+
mask[gene, cluster] = (number of cells with value > 0) / (total cells in cluster) >= threshold
483+
484+
With threshold=0.8, the mask is:
485+
Cluster A: [False, True, False] # Only Gene2 >= 0.8 expression fraction
486+
Cluster B: [False, False, True] # Only Gene3 >= 0.8 expression fraction
487+
488+
A value in the result becomes NaN if either:
489+
- The ligand's mask is False in the source cluster, OR
490+
- The receptor's mask is False in the target cluster
491+
492+
For each cluster pair (A→A, A→B, B→A, B→B) and each gene pair:
493+
A→A: All NaN (no gene pair has both ligand and receptor with mask=True in A)
494+
A→B: Only Gene2→Gene3 is non-NaN (Gene2 has mask=True in A, Gene3 has mask=True in B)
495+
B→A: All NaN (no gene pair has ligand mask=True in B and receptor mask=True in A)
496+
B→B: All NaN (no gene pair has both ligand and receptor with mask=True in B)
497+
498+
Total NaNs = 4 cluster pairs x 3 gene pairs - 1 = 11 NaNs
499+
(The -1 is because A→B for Gene2→Gene3 is the only non-NaN case, where Gene2 has mask=True in A and Gene3 has mask=True in B)
500+
"""
501+
502+
expected_nans = 11
503+
# Setup test data
504+
threshold = 0.8
505+
interactions = pd.DataFrame({"source": ["Gene1", "Gene2", "Gene3"], "target": ["Gene2", "Gene3", "Gene1"]})
506+
507+
# Create sparse matrix with test data
508+
X = csc_matrix(
509+
[
510+
[1.0, 0.1, 0.0], # A1
511+
[0.0, 1.0, 0.0], # A2
512+
[0.0, 1.0, 0.0], # A3
513+
[0.1, 0.0, 1.0], # B1
514+
[0.0, 0.0, 1.0], # B2
515+
[0.0, 0.0, 1.0], # B3
516+
]
517+
)
518+
519+
# Create AnnData object
520+
adata = AnnData(
521+
X=X,
522+
obs=pd.DataFrame({"cluster": ["A"] * 3 + ["B"] * 3}, index=[f"cell{i}" for i in range(1, 7)]),
523+
var=pd.DataFrame(index=["Gene1", "Gene2", "Gene3"]),
524+
)
525+
adata.obs["cluster"] = adata.obs["cluster"].astype("category")
526+
527+
# Run ligrec and compare NaN counts
466528
res = ligrec(
467-
adata_hne,
468-
_CK,
469-
seed=42,
470-
n_perms=5,
471-
show_progress_bar=False,
472-
use_raw=False,
473-
copy=True,
529+
adata, cluster_key="cluster", interactions=interactions, threshold=threshold, use_raw=False, copy=True
474530
)
475-
expected_num_nans_upper = 500_000
476-
num_nans = np.isnan(res["pvalues"].values).sum()
477-
assert num_nans < expected_num_nans_upper
531+
532+
actual_nans = np.sum(np.isnan(res["pvalues"].values))
533+
534+
assert actual_nans == expected_nans, f"NaN count mismatch: expected {expected_nans}, got {actual_nans}"

0 commit comments

Comments
 (0)