|
| 1 | +import matplotlib |
| 2 | +import matplotlib.pyplot as plt |
1 | 3 | import numpy as np |
2 | 4 | import pandas as pd |
3 | 5 | import pytest |
4 | 6 |
|
| 7 | +matplotlib.use("Agg") # Non-interactive backend for testing |
| 8 | + |
5 | 9 |
|
6 | 10 | class TestEvaluate: |
7 | 11 | @pytest.mark.parametrize("eval_layer", ["X", "counts"]) |
@@ -53,3 +57,238 @@ def test_presence_score_groupby(self, cmap, groupby): |
53 | 57 | # Columns should match group names |
54 | 58 | groups = cmap.query.obs[groupby].unique() |
55 | 59 | assert set(df.columns) == set(groups) |
| 60 | + |
| 61 | + |
| 62 | +class TestConfusionMatrix: |
| 63 | + """Tests for plot_confusion_matrix method.""" |
| 64 | + |
| 65 | + def test_plot_confusion_matrix_basic(self, cmap): |
| 66 | + """Test basic confusion matrix plotting without colors.""" |
| 67 | + cmap.map_obs(key="leiden") |
| 68 | + ax = cmap.plot_confusion_matrix("leiden", show_annotation_colors=False) |
| 69 | + assert ax is not None |
| 70 | + plt.close() |
| 71 | + |
| 72 | + def test_plot_confusion_matrix_with_colors(self, cmap): |
| 73 | + """Test confusion matrix plotting with annotation colors.""" |
| 74 | + cmap.map_obs(key="leiden") |
| 75 | + |
| 76 | + # Set explicit colors for testing |
| 77 | + query_cats = cmap.query.obs["leiden"].cat.categories |
| 78 | + ref_cats = cmap.reference.obs["leiden"].cat.categories |
| 79 | + |
| 80 | + # Generate distinct colors for query and reference |
| 81 | + query_colors = [f"#{i * 30:02x}0000" for i in range(len(query_cats))] # Red shades |
| 82 | + ref_colors = [f"#0000{i * 30:02x}" for i in range(len(ref_cats))] # Blue shades |
| 83 | + |
| 84 | + cmap.query.uns["leiden_colors"] = query_colors |
| 85 | + cmap.reference.uns["leiden_colors"] = ref_colors |
| 86 | + |
| 87 | + print(f"Query uns keys: {list(cmap.query.uns.keys())}") |
| 88 | + print(f"Reference uns keys: {list(cmap.reference.uns.keys())}") |
| 89 | + print(f"Query leiden categories: {list(query_cats)}") |
| 90 | + print(f"Query colors: {query_colors}") |
| 91 | + print(f"Reference leiden categories: {list(ref_cats)}") |
| 92 | + print(f"Reference colors: {ref_colors}") |
| 93 | + |
| 94 | + ax = cmap.plot_confusion_matrix("leiden", show_annotation_colors=True) |
| 95 | + assert ax is not None |
| 96 | + |
| 97 | + # Check that annotation strips were added (patches on the axes) |
| 98 | + patches = [p for p in ax.patches if hasattr(p, "get_facecolor")] |
| 99 | + n_rows = len(cmap.query.obs["leiden"].unique()) |
| 100 | + n_cols = len(cmap.reference.obs["leiden"].unique()) |
| 101 | + print(f"Number of patches: {len(patches)} (expected at least {n_rows + n_cols})") |
| 102 | + assert len(patches) >= n_rows + n_cols, f"Should have at least {n_rows + n_cols} patches" |
| 103 | + plt.close() |
| 104 | + |
| 105 | + def test_get_category_colors_helper(self, cmap): |
| 106 | + """Test _get_category_colors helper function directly.""" |
| 107 | + from cellmapper.model.evaluate import _get_category_colors |
| 108 | + |
| 109 | + cmap.map_obs(key="leiden") |
| 110 | + |
| 111 | + # Get categories from confusion matrix |
| 112 | + y_true = cmap.query.obs["leiden"].astype(str) |
| 113 | + y_pred = cmap.query.obs["leiden_pred"].astype(str) |
| 114 | + cm = pd.crosstab(y_true, y_pred) |
| 115 | + |
| 116 | + # Test getting colors for row categories (true labels from query) |
| 117 | + row_cats = list(cm.index) |
| 118 | + row_colors = _get_category_colors(cmap.query, "leiden", row_cats) |
| 119 | + print(f"Row categories: {row_cats}") |
| 120 | + print(f"Row colors: {row_colors}") |
| 121 | + assert len(row_colors) == len(row_cats) |
| 122 | + |
| 123 | + # Test getting colors for col categories (pred labels from reference) |
| 124 | + col_cats = list(cm.columns) |
| 125 | + col_colors = _get_category_colors(cmap.reference, "leiden", col_cats) |
| 126 | + print(f"Col categories: {col_cats}") |
| 127 | + print(f"Col colors: {col_colors}") |
| 128 | + assert len(col_colors) == len(col_cats) |
| 129 | + |
| 130 | + # Check that we're not getting all gray (which would mean colors not found) |
| 131 | + has_real_colors_row = any(c != "gray" for c in row_colors) |
| 132 | + has_real_colors_col = any(c != "gray" for c in col_colors) |
| 133 | + print(f"Row has real colors: {has_real_colors_row}") |
| 134 | + print(f"Col has real colors: {has_real_colors_col}") |
| 135 | + |
| 136 | + # If colors exist in .uns, they should be found |
| 137 | + if "leiden_colors" in cmap.query.uns: |
| 138 | + assert has_real_colors_row, "Should find colors in query.uns" |
| 139 | + if "leiden_colors" in cmap.reference.uns: |
| 140 | + assert has_real_colors_col, "Should find colors in reference.uns" |
| 141 | + |
| 142 | + def test_get_category_colors_with_explicit_colors(self, cmap): |
| 143 | + """Test _get_category_colors when colors are explicitly set.""" |
| 144 | + from cellmapper.model.evaluate import _get_category_colors |
| 145 | + |
| 146 | + cmap.map_obs(key="leiden") |
| 147 | + |
| 148 | + # Explicitly set colors in query and reference |
| 149 | + query_cats = cmap.query.obs["leiden"].cat.categories |
| 150 | + ref_cats = cmap.reference.obs["leiden"].cat.categories |
| 151 | + |
| 152 | + # Generate colors for query |
| 153 | + import matplotlib.pyplot as plt |
| 154 | + |
| 155 | + query_colors = plt.cm.tab10.colors[: len(query_cats)] |
| 156 | + cmap.query.uns["leiden_colors"] = [ |
| 157 | + f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}" for r, g, b in query_colors |
| 158 | + ] |
| 159 | + |
| 160 | + # Generate different colors for reference |
| 161 | + ref_colors = plt.cm.Set3.colors[: len(ref_cats)] |
| 162 | + cmap.reference.uns["leiden_colors"] = [ |
| 163 | + f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}" for r, g, b in ref_colors |
| 164 | + ] |
| 165 | + |
| 166 | + print(f"Query categories: {list(query_cats)}") |
| 167 | + print(f"Query colors: {cmap.query.uns['leiden_colors']}") |
| 168 | + print(f"Reference categories: {list(ref_cats)}") |
| 169 | + print(f"Reference colors: {cmap.reference.uns['leiden_colors']}") |
| 170 | + |
| 171 | + # Get categories from confusion matrix |
| 172 | + y_true = cmap.query.obs["leiden"].astype(str) |
| 173 | + y_pred = cmap.query.obs["leiden_pred"].astype(str) |
| 174 | + cm = pd.crosstab(y_true, y_pred) |
| 175 | + |
| 176 | + # Test getting colors for row categories (true labels from query) |
| 177 | + row_cats = list(cm.index) |
| 178 | + row_colors = _get_category_colors(cmap.query, "leiden", row_cats) |
| 179 | + print(f"CM row categories: {row_cats}") |
| 180 | + print(f"Row colors from query: {row_colors}") |
| 181 | + |
| 182 | + # Test getting colors for col categories (pred labels from reference) |
| 183 | + col_cats = list(cm.columns) |
| 184 | + col_colors = _get_category_colors(cmap.reference, "leiden", col_cats) |
| 185 | + print(f"CM col categories: {col_cats}") |
| 186 | + print(f"Col colors from reference: {col_colors}") |
| 187 | + |
| 188 | + # Verify colors were found (not gray) |
| 189 | + assert all(c != "gray" for c in row_colors), f"Row colors should not be gray: {row_colors}" |
| 190 | + assert all(c != "gray" for c in col_colors), f"Col colors should not be gray: {col_colors}" |
| 191 | + |
| 192 | + # Verify colors match the expected colors from their source adata |
| 193 | + for cat, color in zip(row_cats, row_colors, strict=True): |
| 194 | + cat_idx = list(query_cats).index(cat) |
| 195 | + expected = cmap.query.uns["leiden_colors"][cat_idx] |
| 196 | + assert color == expected, f"Row color mismatch for {cat}: {color} != {expected}" |
| 197 | + |
| 198 | + for cat, color in zip(col_cats, col_colors, strict=True): |
| 199 | + cat_idx = list(ref_cats).index(cat) |
| 200 | + expected = cmap.reference.uns["leiden_colors"][cat_idx] |
| 201 | + assert color == expected, f"Col color mismatch for {cat}: {color} != {expected}" |
| 202 | + |
| 203 | + def test_get_category_colors_mismatched_categories(self, cmap): |
| 204 | + """Test when query and reference have different category orders or subsets.""" |
| 205 | + from cellmapper.model.evaluate import _get_category_colors |
| 206 | + |
| 207 | + cmap.map_obs(key="leiden") |
| 208 | + |
| 209 | + # Check actual category differences between query and reference |
| 210 | + query_cats = set(cmap.query.obs["leiden"].cat.categories) |
| 211 | + ref_cats = set(cmap.reference.obs["leiden"].cat.categories) |
| 212 | + print(f"Query categories: {sorted(query_cats)}") |
| 213 | + print(f"Reference categories: {sorted(ref_cats)}") |
| 214 | + print(f"Only in query: {query_cats - ref_cats}") |
| 215 | + print(f"Only in reference: {ref_cats - query_cats}") |
| 216 | + |
| 217 | + # Set colors with same categories but DIFFERENT ORDER in .uns |
| 218 | + # This simulates what happens when scanpy generates colors independently |
| 219 | + query_cat_list = list(cmap.query.obs["leiden"].cat.categories) |
| 220 | + ref_cat_list = list(cmap.reference.obs["leiden"].cat.categories) |
| 221 | + |
| 222 | + # Query colors in natural order |
| 223 | + |
| 224 | + query_colors = [f"#query{i:02d}" for i in range(len(query_cat_list))] |
| 225 | + cmap.query.uns["leiden_colors"] = query_colors |
| 226 | + |
| 227 | + # Reference colors - same categories but colors assigned to different indices |
| 228 | + ref_colors = [f"#ref{i:02d}" for i in range(len(ref_cat_list))] |
| 229 | + cmap.reference.uns["leiden_colors"] = ref_colors |
| 230 | + |
| 231 | + print("\nQuery category -> color mapping:") |
| 232 | + for cat, col in zip(query_cat_list, query_colors, strict=True): |
| 233 | + print(f" {cat} -> {col}") |
| 234 | + |
| 235 | + print("\nReference category -> color mapping:") |
| 236 | + for cat, col in zip(ref_cat_list, ref_colors, strict=True): |
| 237 | + print(f" {cat} -> {col}") |
| 238 | + |
| 239 | + # Now test color retrieval |
| 240 | + test_cats = sorted(query_cats | ref_cats) |
| 241 | + row_colors = _get_category_colors(cmap.query, "leiden", test_cats) |
| 242 | + col_colors = _get_category_colors(cmap.reference, "leiden", test_cats) |
| 243 | + |
| 244 | + print(f"\nRetrieved colors for test_cats={test_cats}:") |
| 245 | + print(f"From query (rows): {row_colors}") |
| 246 | + print(f"From reference (cols): {col_colors}") |
| 247 | + |
| 248 | + # Verify each category maps to the correct color from its source |
| 249 | + for i, cat in enumerate(test_cats): |
| 250 | + if cat in query_cat_list: |
| 251 | + expected_row = query_colors[query_cat_list.index(cat)] |
| 252 | + assert row_colors[i] == expected_row, f"Query color wrong for {cat}" |
| 253 | + else: |
| 254 | + assert row_colors[i] == "gray", f"Missing query cat {cat} should be gray" |
| 255 | + |
| 256 | + if cat in ref_cat_list: |
| 257 | + expected_col = ref_colors[ref_cat_list.index(cat)] |
| 258 | + assert col_colors[i] == expected_col, f"Reference color wrong for {cat}" |
| 259 | + else: |
| 260 | + assert col_colors[i] == "gray", f"Missing ref cat {cat} should be gray" |
| 261 | + |
| 262 | + def test_get_category_colors_type_error(self): |
| 263 | + """Test that _get_category_colors raises TypeError for wrong input.""" |
| 264 | + from cellmapper.model.evaluate import _get_category_colors |
| 265 | + |
| 266 | + with pytest.raises(TypeError, match="Expected AnnData"): |
| 267 | + _get_category_colors([1, 2, 3], "leiden", ["A", "B"]) |
| 268 | + |
| 269 | + def test_plot_confusion_matrix_partial_colors(self, cmap): |
| 270 | + """Test when only reference has colors but query doesn't.""" |
| 271 | + cmap.map_obs(key="leiden") |
| 272 | + |
| 273 | + # Only set colors for reference (simulating common scenario where |
| 274 | + # query is spatial data without colors and reference is atlas with colors) |
| 275 | + ref_cats = cmap.reference.obs["leiden"].cat.categories |
| 276 | + ref_colors = [f"#00{i * 30:02x}00" for i in range(len(ref_cats))] # Green shades |
| 277 | + cmap.reference.uns["leiden_colors"] = ref_colors |
| 278 | + |
| 279 | + # Make sure query doesn't have colors |
| 280 | + if "leiden_colors" in cmap.query.uns: |
| 281 | + del cmap.query.uns["leiden_colors"] |
| 282 | + |
| 283 | + print(f"Query has leiden_colors: {'leiden_colors' in cmap.query.uns}") |
| 284 | + print(f"Reference has leiden_colors: {'leiden_colors' in cmap.reference.uns}") |
| 285 | + |
| 286 | + ax = cmap.plot_confusion_matrix("leiden", show_annotation_colors=True) |
| 287 | + patches = [p for p in ax.patches if hasattr(p, "get_facecolor")] |
| 288 | + |
| 289 | + # Should still have patches (some gray, some colored) |
| 290 | + n_rows = len(cmap.query.obs["leiden"].unique()) |
| 291 | + n_cols = len(cmap.reference.obs["leiden"].unique()) |
| 292 | + print(f"Number of patches: {len(patches)}") |
| 293 | + assert len(patches) >= n_rows + n_cols |
| 294 | + plt.close() |
0 commit comments