Skip to content

Commit ffcb201

Browse files
committed
test: add comprehensive tests for plot_confusion_matrix colors
1 parent 26c5876 commit ffcb201

File tree

1 file changed

+239
-0
lines changed

1 file changed

+239
-0
lines changed

tests/model/test_evaluate.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import matplotlib
2+
import matplotlib.pyplot as plt
13
import numpy as np
24
import pandas as pd
35
import pytest
46

7+
# Backend is selected at module level, so it is in effect for every test defined in this file.
matplotlib.use("Agg") # Non-interactive backend for testing
8+
59

610
class TestEvaluate:
711
@pytest.mark.parametrize("eval_layer", ["X", "counts"])
@@ -53,3 +57,238 @@ def test_presence_score_groupby(self, cmap, groupby):
5357
# Columns should match group names
5458
groups = cmap.query.obs[groupby].unique()
5559
assert set(df.columns) == set(groups)
60+
61+
62+
class TestConfusionMatrix:
63+
"""Tests for plot_confusion_matrix method."""
64+
65+
def test_plot_confusion_matrix_basic(self, cmap):
    """Plot the confusion matrix with annotation colors disabled.

    Maps the ``"leiden"`` labels, calls ``plot_confusion_matrix`` with
    ``show_annotation_colors=False`` and only checks that the call returns
    something other than ``None``.
    """
    cmap.map_obs(key="leiden")
    try:
        ax = cmap.plot_confusion_matrix("leiden", show_annotation_colors=False)
        assert ax is not None
    finally:
        # Close figures even when the call or the assertion fails, so a
        # failing test cannot leak open figures into the tests that follow.
        plt.close("all")
71+
72+
def test_plot_confusion_matrix_with_colors(self, cmap):
    """Plot the confusion matrix with explicit annotation colors in ``.uns``.

    Stores a red ramp under ``query.uns["leiden_colors"]`` and a blue ramp
    under ``reference.uns["leiden_colors"]``, plots with
    ``show_annotation_colors=True`` and checks that the returned axes carry
    at least one patch per query label plus one per reference label.
    """
    cmap.map_obs(key="leiden")

    query_cats = cmap.query.obs["leiden"].cat.categories
    ref_cats = cmap.reference.obs["leiden"].cat.categories

    # The channel value wraps at 256: ``02x`` is only a *minimum* width, so an
    # unbounded ``i * 30`` (270 at i == 9) would format as three hex digits
    # and produce a malformed seven-digit color string.
    cmap.query.uns["leiden_colors"] = [f"#{(i * 30) % 256:02x}0000" for i in range(len(query_cats))]  # red shades
    cmap.reference.uns["leiden_colors"] = [f"#0000{(i * 30) % 256:02x}" for i in range(len(ref_cats))]  # blue shades

    try:
        ax = cmap.plot_confusion_matrix("leiden", show_annotation_colors=True)
        assert ax is not None

        # Annotation strips are expected to show up as patches on the axes.
        patches = [p for p in ax.patches if hasattr(p, "get_facecolor")]
        n_rows = len(cmap.query.obs["leiden"].unique())
        n_cols = len(cmap.reference.obs["leiden"].unique())
        assert len(patches) >= n_rows + n_cols, f"Should have at least {n_rows + n_cols} patches"
    finally:
        # Close figures even on failure so they do not leak into later tests.
        plt.close("all")
104+
105+
def test_get_category_colors_helper(self, cmap):
    """Call ``_get_category_colors`` directly on the labels of a true/predicted crosstab.

    Row labels are looked up against ``cmap.query`` and column labels against
    ``cmap.reference``. Each call must return one color per requested label,
    and whenever ``"leiden_colors"`` is present in the corresponding ``.uns``
    at least one returned color must differ from ``"gray"``.
    """
    from cellmapper.model.evaluate import _get_category_colors

    cmap.map_obs(key="leiden")

    # Cross-tabulate true vs. predicted labels; index and columns supply the categories.
    confusion = pd.crosstab(
        cmap.query.obs["leiden"].astype(str),
        cmap.query.obs["leiden_pred"].astype(str),
    )

    # Rows: true labels, resolved against the query object.
    row_cats = [*confusion.index]
    row_colors = _get_category_colors(cmap.query, "leiden", row_cats)
    print(f"Row categories: {row_cats}")
    print(f"Row colors: {row_colors}")
    assert len(row_colors) == len(row_cats)

    # Columns: predicted labels, resolved against the reference object.
    col_cats = [*confusion.columns]
    col_colors = _get_category_colors(cmap.reference, "leiden", col_cats)
    print(f"Col categories: {col_cats}")
    print(f"Col colors: {col_colors}")
    assert len(col_colors) == len(col_cats)

    # An all-"gray" result is treated as "no colors were found".
    has_real_colors_row = not all(shade == "gray" for shade in row_colors)
    has_real_colors_col = not all(shade == "gray" for shade in col_colors)
    print(f"Row has real colors: {has_real_colors_row}")
    print(f"Col has real colors: {has_real_colors_col}")

    # Colors are only required to be found when .uns actually provides them.
    checks = (
        (cmap.query, has_real_colors_row, "Should find colors in query.uns"),
        (cmap.reference, has_real_colors_col, "Should find colors in reference.uns"),
    )
    for adata, found, message in checks:
        if "leiden_colors" in adata.uns:
            assert found, message
141+
142+
def test_get_category_colors_with_explicit_colors(self, cmap):
    """Check that ``_get_category_colors`` returns the colors stored in ``.uns``.

    Writes one hex color per category into ``query.uns["leiden_colors"]``
    (from ``tab10``) and ``reference.uns["leiden_colors"]`` (from ``Set3``),
    then requests colors for the row/column labels of a true/predicted
    crosstab. No returned color may be ``"gray"``, and each must equal the
    color stored at that category's position in the respective object.
    """
    from matplotlib.colors import to_hex

    from cellmapper.model.evaluate import _get_category_colors

    cmap.map_obs(key="leiden")

    query_cats = list(cmap.query.obs["leiden"].cat.categories)
    ref_cats = list(cmap.reference.obs["leiden"].cat.categories)

    # Palette indices wrap around: tab10 and Set3 are finite, and a plain
    # slice would leave .uns with fewer colors than categories, which makes
    # the per-category lookup below fail for the extra categories.
    tab10 = plt.cm.tab10.colors
    set3 = plt.cm.Set3.colors
    query_colors = [to_hex(tab10[i % len(tab10)]) for i in range(len(query_cats))]
    ref_colors = [to_hex(set3[i % len(set3)]) for i in range(len(ref_cats))]
    cmap.query.uns["leiden_colors"] = query_colors
    cmap.reference.uns["leiden_colors"] = ref_colors

    # Category -> color as stored; colors are positional with respect to the categories.
    expected_query = dict(zip(query_cats, query_colors, strict=True))
    expected_ref = dict(zip(ref_cats, ref_colors, strict=True))

    # Labels come from the confusion matrix of true vs. predicted labels.
    y_true = cmap.query.obs["leiden"].astype(str)
    y_pred = cmap.query.obs["leiden_pred"].astype(str)
    cm = pd.crosstab(y_true, y_pred)

    # Rows (true labels) are resolved against the query, columns
    # (predicted labels) against the reference.
    row_cats = list(cm.index)
    row_colors = _get_category_colors(cmap.query, "leiden", row_cats)
    col_cats = list(cm.columns)
    col_colors = _get_category_colors(cmap.reference, "leiden", col_cats)

    assert all(c != "gray" for c in row_colors), f"Row colors should not be gray: {row_colors}"
    assert all(c != "gray" for c in col_colors), f"Col colors should not be gray: {col_colors}"

    for cat, color in zip(row_cats, row_colors, strict=True):
        expected = expected_query[cat]
        assert color == expected, f"Row color mismatch for {cat}: {color} != {expected}"

    for cat, color in zip(col_cats, col_colors, strict=True):
        expected = expected_ref[cat]
        assert color == expected, f"Col color mismatch for {cat}: {color} != {expected}"
202+
203+
def test_get_category_colors_mismatched_categories(self, cmap):
    """Check that colors are resolved per object over the union of categories.

    Query and reference each get their own, distinguishable color list (both
    written in that object's category order). Colors are then requested from
    each object for the sorted union of both category sets: a category known
    to the object must yield that object's stored color, and a category the
    object does not have must yield ``"gray"``.
    """
    from cellmapper.model.evaluate import _get_category_colors

    cmap.map_obs(key="leiden")

    query_cat_list = list(cmap.query.obs["leiden"].cat.categories)
    ref_cat_list = list(cmap.reference.obs["leiden"].cat.categories)

    # Sentinel strings rather than real colors: the distinct prefixes reveal
    # which object a returned value came from, and the assertions below
    # expect them back verbatim.
    query_colors = [f"#query{i:02d}" for i in range(len(query_cat_list))]
    ref_colors = [f"#ref{i:02d}" for i in range(len(ref_cat_list))]
    cmap.query.uns["leiden_colors"] = query_colors
    cmap.reference.uns["leiden_colors"] = ref_colors

    # Category -> expected color, built once instead of scanning the lists per category.
    expected_query = dict(zip(query_cat_list, query_colors, strict=True))
    expected_ref = dict(zip(ref_cat_list, ref_colors, strict=True))

    test_cats = sorted(set(query_cat_list) | set(ref_cat_list))
    row_colors = _get_category_colors(cmap.query, "leiden", test_cats)
    col_colors = _get_category_colors(cmap.reference, "leiden", test_cats)

    # strict=True also fails the test if either result has the wrong length.
    for cat, row_color, col_color in zip(test_cats, row_colors, col_colors, strict=True):
        if cat in expected_query:
            assert row_color == expected_query[cat], f"Query color wrong for {cat}"
        else:
            assert row_color == "gray", f"Missing query cat {cat} should be gray"

        if cat in expected_ref:
            assert col_color == expected_ref[cat], f"Reference color wrong for {cat}"
        else:
            assert col_color == "gray", f"Missing ref cat {cat} should be gray"
261+
262+
def test_get_category_colors_type_error(self):
    """A plain list in place of the data object must raise ``TypeError``.

    The raised message has to match ``"Expected AnnData"``.
    """
    from cellmapper.model.evaluate import _get_category_colors

    not_anndata = [1, 2, 3]
    requested = ["A", "B"]
    expect_type_error = pytest.raises(TypeError, match="Expected AnnData")
    with expect_type_error:
        _get_category_colors(not_anndata, "leiden", requested)
268+
269+
def test_plot_confusion_matrix_partial_colors(self, cmap):
    """Plot with annotation colors when only the reference has colors in ``.uns``.

    A green ramp is stored under ``reference.uns["leiden_colors"]`` while
    any ``"leiden_colors"`` entry is removed from ``query.uns``. The plot
    must still return axes carrying at least one patch per query label plus
    one per reference label.
    """
    cmap.map_obs(key="leiden")

    ref_cats = cmap.reference.obs["leiden"].cat.categories
    # The channel value wraps at 256: ``02x`` is only a *minimum* width, so an
    # unbounded ``i * 30`` (270 at i == 9) would format as three hex digits
    # and produce a malformed seven-digit color string.
    cmap.reference.uns["leiden_colors"] = [f"#00{(i * 30) % 256:02x}00" for i in range(len(ref_cats))]  # green shades

    # Ensure the query has no colors, whether or not the fixture provided any.
    cmap.query.uns.pop("leiden_colors", None)

    try:
        ax = cmap.plot_confusion_matrix("leiden", show_annotation_colors=True)
        assert ax is not None

        patches = [p for p in ax.patches if hasattr(p, "get_facecolor")]
        n_rows = len(cmap.query.obs["leiden"].unique())
        n_cols = len(cmap.reference.obs["leiden"].unique())
        assert len(patches) >= n_rows + n_cols
    finally:
        # Close figures even on failure so they do not leak into later tests.
        plt.close("all")

0 commit comments

Comments
 (0)