Skip to content
4 changes: 2 additions & 2 deletions src/squidpy/gr/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,11 @@ def extractor(res: Sequence[TempResult]) -> TempResult:
clustering = np.array(data["clusters"].values, dtype=np.int32)

mean = groups.mean().values.T # (n_genes, n_clusters)
mask = groups.apply(lambda c: ((c > 0).sum() / len(c)) >= threshold).values.T # (n_genes, n_clusters)
mask = groups.apply(lambda c: ((c > 0).astype(int).sum() / len(c)) >= threshold).values.T # (n_genes, n_clusters)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's link to the issue on why this astype cast is there. And use "int64" to be explicit (which is what int does anyway).


# (n_cells, n_genes)
data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C")
# all 3 should be C contiguous

return parallelize( # type: ignore[no-any-return]
_analysis_helper,
np.arange(n_perms, dtype=np.int32).tolist(),
Expand Down
71 changes: 71 additions & 0 deletions tests/graph/test_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pandas.testing import assert_frame_equal
from scanpy import settings as s
from scanpy.datasets import blobs
from scipy.sparse import csc_matrix

from squidpy._constants._pkg_constants import Key
from squidpy.gr import ligrec
Expand Down Expand Up @@ -461,3 +462,73 @@ def test_none_source_target(self, adata: AnnData):
)
assert isinstance(pt.interactions, pd.DataFrame)
assert len(pt.interactions) == 1

def test_ligrec_nan_counts(self):
"""
For the test case with 2 clusters (A, B) and 3 gene pairs (Gene1→Gene2, Gene2→Gene3, Gene3→Gene1):

Expression fractions per cluster (computed as fraction of cells with value > 0):
Cluster A:
- Gene1: 1/3 = 0.33 (only A1 has value > 0)
- Gene2: 2/3 = 0.67 (A2 and A3 have value > 0)
- Gene3: 0/3 = 0.00 (none have value > 0)

Cluster B:
- Gene1: 1/3 = 0.33 (only B1 has value > 0)
- Gene2: 0/3 = 0.00 (none have value > 0)
- Gene3: 3/3 = 1.00 (all have value > 0)

The mask is computed for each gene in each cluster as:
mask[gene, cluster] = (number of cells with value > 0) / (total cells in cluster) >= threshold

With threshold=0.8, the mask is:
Cluster A: [False, False, False] # All genes < 0.8 expression fraction
Cluster B: [False, False, True] # Only Gene3 >= 0.8 expression fraction

A value in the result becomes NaN if either:
- The ligand's mask is False in the source cluster, OR
- The receptor's mask is False in the target cluster
Comment on lines +485 to +487
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, doesn't this imply they both have to be True to be non-NaN? Wouldn't that exlude Gene2→Gene3?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right thanks for checking it out. Turns out I didn't fully understand it but I think I now got it and made the test a bit more specific as well. I also fixed the explanation.


For each cluster pair (A→A, A→B, B→A, B→B) and each gene pair:
A→A: All NaN (all genes have mask=False in A)
A→B: All NaN (all genes have mask=False in A)
B→A: All NaN (all genes have mask=False in A)
B→B: Only Gene2→Gene3 is non-NaN (Gene3 has mask=True in B)

Total NaNs = 4 cluster pairs x 3 gene pairs - 1 = 11 NaNs
(The -1 is because B→B for Gene2→Gene3 is the only non-NaN case, where Gene3 has mask=True in B)
"""

expected_nans = 11
# Setup test data
threshold = 0.8
interactions = pd.DataFrame({"source": ["Gene1", "Gene2", "Gene3"], "target": ["Gene2", "Gene3", "Gene1"]})

# Create sparse matrix with test data
X = csc_matrix(
[
[1.0, 0.1, 0.0], # A1
[0.0, 1.0, 0.0], # A2
[0.0, 1.0, 0.0], # A3
[0.1, 0.0, 1.0], # B1
[0.0, 0.0, 1.0], # B2
[0.0, 0.0, 1.0], # B3
]
)

# Create AnnData object
adata = AnnData(
X=X,
obs=pd.DataFrame({"cluster": ["A"] * 3 + ["B"] * 3}, index=[f"cell{i}" for i in range(1, 7)]),
var=pd.DataFrame(index=["Gene1", "Gene2", "Gene3"]),
)
adata.obs["cluster"] = adata.obs["cluster"].astype("category")

# Run ligrec and compare NaN counts
res = ligrec(
adata, cluster_key="cluster", interactions=interactions, threshold=threshold, use_raw=False, copy=True
)

actual_nans = np.sum(np.isnan(res["pvalues"].values))

assert actual_nans == expected_nans, f"NaN count mismatch: expected {expected_nans}, got {actual_nans}"
Loading