Skip to content

Commit fa751d7

Browse files
committed
ban slashes in keys where we can’t write
1 parent 07a6f67 commit fa751d7

5 files changed

Lines changed: 57 additions & 27 deletions

File tree

src/anndata/_io/h5ad.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .specs import read_elem, write_elem
2929
from .specs.registry import IOSpec, write_spec
3030
from .utils import (
31+
_check_has_no_slash_key,
3132
_read_legacy_raw,
3233
idx_chunks_along_axis,
3334
no_write_dataset_2d,
@@ -86,6 +87,8 @@ def write_h5ad(
8687
f.attrs.setdefault("encoding-type", "anndata")
8788
f.attrs.setdefault("encoding-version", "0.1.0")
8889
for k, elem in iter_outer(adata):
90+
_check_has_no_slash_key(k, elem)
91+
8992
if k == "raw":
9093
_write_raw(
9194
f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs

src/anndata/_io/specs/methods.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
from anndata._core.index import _normalize_indices
2121
from anndata._core.merge import intersect_keys
2222
from anndata._core.sparse_dataset import _CSCDataset, _CSRDataset, sparse_dataset
23-
from anndata._io.utils import check_key, zero_dim_array_as_scalar
23+
from anndata._io.utils import (
24+
_check_has_no_slash_key,
25+
check_key,
26+
zero_dim_array_as_scalar,
27+
)
2428
from anndata._warnings import OldFormatWarning
2529
from anndata.compat import (
2630
AwkArray,
@@ -287,19 +291,19 @@ def write_anndata(
287291
):
288292
g = f.require_group(k)
289293
for sub_key, elem in iter_outer(adata):
290-
if not (sub_key == "X" and elem is None):
291-
if sub_key == "layers":
292-
if None in elem:
293-
_writer.write_elem(
294-
g, "X", elem[None], dataset_kwargs=dataset_kwargs
295-
)
296-
elem = {k: v for k, v in elem.items() if k is not None}
297-
_writer.write_elem(
298-
g,
299-
sub_key,
300-
dict(elem) if isinstance(elem, MutableMapping) else elem,
301-
dataset_kwargs=dataset_kwargs,
302-
)
294+
if sub_key == "X" and elem is None:
295+
continue
296+
_check_has_no_slash_key(sub_key, elem)
297+
if sub_key == "layers":
298+
if None in elem:
299+
_writer.write_elem(g, "X", elem[None], dataset_kwargs=dataset_kwargs)
300+
elem = {k: v for k, v in elem.items() if k is not None}
301+
_writer.write_elem(
302+
g,
303+
sub_key,
304+
dict(elem) if isinstance(elem, MutableMapping) else elem,
305+
dataset_kwargs=dataset_kwargs,
306+
)
303307

304308

305309
@_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0"))

src/anndata/_io/specs/registry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,10 @@ def write_elem(
368368
if k.startswith(store.name) and k != "/":
369369
k = str(PurePosixPath(k).relative_to(store.name))
370370

371+
# Apart from this code, we also ban keys containing slashes in `write_anndata`/`write_h5ad`
372+
# for AnnData elements other than `obs`, `var`, and `uns`.
371373
if "/" in k and k != "/":
372-
if settings.disallow_forward_slash_in_h5ad:
374+
if isinstance(store, ZarrGroup) or settings.disallow_forward_slash_in_h5ad:
373375
msg = f"Forward slashes are not allowed in keys in {type(store)}"
374376
raise ValueError(msg)
375377
msg = "Forward slashes will be written differently in a future anndata version"

src/anndata/_io/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Mapping
44
from functools import WRAPPER_ASSIGNMENTS, cache, wraps
55
from itertools import pairwise
66
from typing import TYPE_CHECKING, Literal, cast
@@ -12,7 +12,7 @@
1212
from ..utils import warn
1313

1414
if TYPE_CHECKING:
15-
from collections.abc import Callable, Mapping
15+
from collections.abc import Callable
1616
from typing import Any, Literal
1717

1818
from pandas.core.dtypes.dtypes import BaseMaskedDtype
@@ -279,6 +279,16 @@ def func_wrapper(*args, **kwargs):
279279
return func_wrapper
280280

281281

282+
def _check_has_no_slash_key(attr: str, elem: object) -> None:
283+
"""Only attempt to write slash keys where people rely on it for backwards compatibility."""
284+
if attr in {"obs", "var", "uns", "raw"}:
285+
return # separate check for `settings.disallow_forward_slash_in_h5ad` is done in `write_elem`
286+
assert isinstance(elem, Mapping)
287+
if any("/" in k for k in elem if k not in {"/", None}):
288+
msg = f"Forward slashes are not allowed in keys in {attr}"
289+
raise ValueError(msg)
290+
291+
282292
# -------------------------------------------------------------------------------
283293
# Common h5ad/zarr stuff
284294
# -------------------------------------------------------------------------------

tests/test_readwrite.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,9 @@ def test_h5py_attr_limit(tmp_path):
963963
"elem_key", ["obs", "var", "obsm", "varm", "layers", "obsp", "varp", "uns"]
964964
)
965965
@pytest.mark.parametrize("store_type", ["zarr", "h5ad"])
966-
@pytest.mark.parametrize("disallow_forward_slash_in_h5ad", [True, False])
966+
@pytest.mark.parametrize(
967+
"disallow_forward_slash_in_h5ad", [True, False], ids=["ban_slash", "allow_slash"]
968+
)
967969
def test_forward_slash_key(
968970
elem_key: Literal["obs", "var", "obsm", "varm", "layers", "obsp", "varp", "uns"],
969971
tmp_path: Path,
@@ -976,25 +978,34 @@ def test_forward_slash_key(
976978
(10,) if elem_key in ["obs", "var"] else (10, 10)
977979
)
978980
path = tmp_path / f"test.{store_type}"
981+
can_write_slash_key = (
982+
elem_key in {"uns", "obs", "var"}
983+
and store_type == "h5ad"
984+
and not disallow_forward_slash_in_h5ad
985+
)
979986

987+
# try to write bad key and make sure we warn or throw an error
980988
with (
981989
ad.settings.override(
982990
disallow_forward_slash_in_h5ad=disallow_forward_slash_in_h5ad
983991
),
984-
pytest.raises(ValueError, match=r"Forward slashes")
985-
if disallow_forward_slash_in_h5ad
986-
else pytest.warns(FutureWarning, match=r"Forward slashes"),
992+
pytest.warns(FutureWarning, match=r"Forward slashes")
993+
if can_write_slash_key
994+
else pytest.raises(ValueError, match=r"Forward slashes"),
987995
):
988996
getattr(a, f"write_{store_type}")(path)
989997

990-
if not disallow_forward_slash_in_h5ad:
991-
adata = getattr(ad, f"read_{store_type}")(path)
998+
# read and check that bad keys were only written if allowed
999+
elem = getattr(getattr(ad, f"read_{store_type}")(path), elem_key)
1000+
if can_write_slash_key:
9921001
if elem_key in {"obs", "var"}:
993-
assert "bad/key" in getattr(adata, elem_key)
994-
elif elem_key == "uns":
995-
assert "bad" in getattr(adata, elem_key)
1002+
assert "bad/key" in elem
9961003
else:
997-
assert not getattr(adata, elem_key).keys()
1004+
assert "bad" in elem
1005+
elif elem_key in {"obs", "var"}:
1006+
assert set(elem.columns) == {"_index"}
1007+
else:
1008+
assert not elem.keys()
9981009

9991010

10001011
@pytest.mark.skipif(

0 commit comments

Comments
 (0)