Skip to content

Commit a5e25fd

Browse files
committed
implement __contains__
1 parent 55ffb17 commit a5e25fd

3 files changed

Lines changed: 12 additions & 0 deletions

File tree

src/mudata/_core/mudata.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,14 @@ def __getitem__(self, index) -> AnnData | MuData:
547547
else:
548548
return MuData(self, as_view=True, index=index)
549549

550+
def __contains__(self, key) -> bool:
551+
if isinstance(key, str):
552+
return key in self._mod
553+
elif type(key).__module__.startswith("anndata.acc") and type(key).__name__ == "AdRef":
554+
return AnnData.__contains__.__get__(self)(key)
555+
else:
556+
raise TypeError(f"Unexpected key {key!r}.")
557+
550558
@property
551559
def mod(self) -> Mapping[str, AnnData | MuData]:
552560
"""Dictionary of modalities."""

tests/test_obs_var.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def test_names_make_unique(mdata: md.MuData):
153153
Version(ad.__version__) < Version("0.13dev0"), reason="anndata version too old, no accessor support"
154154
)
155155
def test_accessors(mdata: md.MuData):
156+
assert ad.acc.A.obs["arange"] in mdata
156157
assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all()
157158
with pytest.raises(KeyError, match="test"):
158159
mdata[ad.acc.A.var["test"]]

tests/test_update.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def test_update_simple(mdata: MuData, axis: Axis):
146146
for mod in mdata.mod.keys():
147147
assert mdata.obsmap[mod].dtype.kind == "u"
148148
assert mdata.varmap[mod].dtype.kind == "u"
149+
assert mod in mdata
150+
with pytest.raises(TypeError):
151+
1 in mdata # noqa: B015
149152

150153
# names along non-axis are concatenated
151154
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())

0 commit comments

Comments
 (0)