Skip to content

Commit 3a61b7a

Browse files
committed
Test ir_dist TCRBLOSUM chain routing
1 parent fae4a00 commit 3a61b7a

1 file changed

Lines changed: 39 additions & 1 deletion

File tree

src/scirpy/tests/test_ir_dist.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from scirpy.ir_dist.metrics import DistanceCalculator
1313
from scirpy.util import DataHandler, _is_symmetric
1414

15-
from .util import _squarify
15+
from .util import _make_adata, _squarify
1616

1717

1818
def _assert_frame_equal(left, right):
@@ -163,6 +163,44 @@ def test_ir_dist(
163163
npt.assert_array_equal(res["VDJ"]["distances"].toarray(), expected_dist_vdj)
164164

165165

166+
@pytest.mark.parametrize("mudata", [False, True], ids=["AnnData", "MuData"])
167+
def test_ir_dist_tcrdist_tcrblosum_chain_routing(mudata):
168+
# `ir_dist` should automatically route VJ to TCRBLOSUM alpha and VDJ to beta.
169+
adata = _make_adata(
170+
pd.DataFrame(
171+
[
172+
["cell1", "AAACAAAA", "AAACAAAA", "TRA", "TRB"],
173+
["cell2", "AAAHAAAA", "AAAHAAAA", "TRA", "TRB"],
174+
],
175+
columns=[
176+
"cell_id",
177+
"IR_VJ_1_junction_aa",
178+
"IR_VDJ_1_junction_aa",
179+
"IR_VJ_1_locus",
180+
"IR_VDJ_1_locus",
181+
],
182+
).set_index("cell_id"),
183+
mudata,
184+
)
185+
186+
ir.pp.ir_dist(
187+
adata,
188+
metric="tcrdist",
189+
sequence="aa",
190+
cutoff=20,
191+
base_matrix="tcrblosum",
192+
key_added="ir_dist_tcrblosum",
193+
n_jobs=1,
194+
)
195+
196+
res = adata.mod["airr"].uns["ir_dist_tcrblosum"] if isinstance(adata, MuData) else adata.uns["ir_dist_tcrblosum"]
197+
expected_seqs = np.array(["AAACAAAA", "AAAHAAAA"])
198+
npt.assert_array_equal(res["VJ"]["seqs"], expected_seqs)
199+
npt.assert_array_equal(res["VDJ"]["seqs"], expected_seqs)
200+
npt.assert_array_equal(res["VJ"]["distances"].toarray(), np.array([[1, 16], [16, 1]]))
201+
npt.assert_array_equal(res["VDJ"]["distances"].toarray(), np.array([[1, 19], [19, 1]]))
202+
203+
166204
@pytest.mark.parametrize("with_adata2", [False, True])
167205
@pytest.mark.parametrize("joblib_backend", ["loky", "multiprocessing", "threading"])
168206
@pytest.mark.parametrize("n_jobs", [1, 2])

0 commit comments

Comments
 (0)