|
12 | 12 | from scirpy.ir_dist.metrics import DistanceCalculator |
13 | 13 | from scirpy.util import DataHandler, _is_symmetric |
14 | 14 |
|
15 | | -from .util import _squarify |
| 15 | +from .util import _make_adata, _squarify |
16 | 16 |
|
17 | 17 |
|
18 | 18 | def _assert_frame_equal(left, right): |
@@ -163,6 +163,44 @@ def test_ir_dist( |
163 | 163 | npt.assert_array_equal(res["VDJ"]["distances"].toarray(), expected_dist_vdj) |
164 | 164 |
|
165 | 165 |
|
| 166 | +@pytest.mark.parametrize("mudata", [False, True], ids=["AnnData", "MuData"]) |
| 167 | +def test_ir_dist_tcrdist_tcrblosum_chain_routing(mudata): |
| 168 | + # `ir_dist` should automatically route VJ to TCRBLOSUM alpha and VDJ to beta. |
| 169 | + adata = _make_adata( |
| 170 | + pd.DataFrame( |
| 171 | + [ |
| 172 | + ["cell1", "AAACAAAA", "AAACAAAA", "TRA", "TRB"], |
| 173 | + ["cell2", "AAAHAAAA", "AAAHAAAA", "TRA", "TRB"], |
| 174 | + ], |
| 175 | + columns=[ |
| 176 | + "cell_id", |
| 177 | + "IR_VJ_1_junction_aa", |
| 178 | + "IR_VDJ_1_junction_aa", |
| 179 | + "IR_VJ_1_locus", |
| 180 | + "IR_VDJ_1_locus", |
| 181 | + ], |
| 182 | + ).set_index("cell_id"), |
| 183 | + mudata, |
| 184 | + ) |
| 185 | + |
| 186 | + ir.pp.ir_dist( |
| 187 | + adata, |
| 188 | + metric="tcrdist", |
| 189 | + sequence="aa", |
| 190 | + cutoff=20, |
| 191 | + base_matrix="tcrblosum", |
| 192 | + key_added="ir_dist_tcrblosum", |
| 193 | + n_jobs=1, |
| 194 | + ) |
| 195 | + |
| 196 | + res = adata.mod["airr"].uns["ir_dist_tcrblosum"] if isinstance(adata, MuData) else adata.uns["ir_dist_tcrblosum"] |
| 197 | + expected_seqs = np.array(["AAACAAAA", "AAAHAAAA"]) |
| 198 | + npt.assert_array_equal(res["VJ"]["seqs"], expected_seqs) |
| 199 | + npt.assert_array_equal(res["VDJ"]["seqs"], expected_seqs) |
| 200 | + npt.assert_array_equal(res["VJ"]["distances"].toarray(), np.array([[1, 16], [16, 1]])) |
| 201 | + npt.assert_array_equal(res["VDJ"]["distances"].toarray(), np.array([[1, 19], [19, 1]])) |
| 202 | + |
| 203 | + |
166 | 204 | @pytest.mark.parametrize("with_adata2", [False, True]) |
167 | 205 | @pytest.mark.parametrize("joblib_backend", ["loky", "multiprocessing", "threading"]) |
168 | 206 | @pytest.mark.parametrize("n_jobs", [1, 2]) |
|
0 commit comments