Skip to content

Commit 6f4bc3b

Browse files
committed
disallow / keys
1 parent 7117ac1 commit 6f4bc3b

3 files changed

Lines changed: 64 additions & 28 deletions

File tree

src/anndata/_io/specs/registry.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -364,41 +364,47 @@ def write_elem(
364364

365365
from anndata._io.zarr import is_group_consolidated
366366

367-
# we allow stores to have a prefix like /uns which are then written to with keys like /uns/foo
368-
if k.startswith(store.name) and k != "/":
369-
k = str(PurePosixPath(k).relative_to(store.name))
370-
371-
# Apart from this code, we also ban keys containing slashes in `write_adata`/`write_h5ad`
372-
# for AnnData elements other than `obs`, `var`, and `uns`.
373-
if "/" in k and k != "/":
374-
if isinstance(store, ZarrGroup) or settings.disallow_forward_slash_in_h5ad:
375-
msg = f"Forward slashes are not allowed in keys in {type(store)}"
376-
raise ValueError(msg)
377-
msg = "Forward slashes will be written differently in a future anndata version"
378-
warn(msg, FutureWarning)
379-
380367
if isinstance(store, h5py.File):
381368
store = store["/"]
382-
383-
dest_type = type(store)
384-
385-
if is_group_consolidated(store, strict=False):
369+
elif is_group_consolidated(store, strict=False):
386370
msg = "Cannot overwrite/edit a store with consolidated metadata"
387371
raise ValueError(msg)
372+
388373
if k == "/":
374+
if store.name != "/":
375+
msg = f"'/' is not in the subpath of {store.name!r}"
376+
raise ValueError(msg)
377+
389378
if isinstance(store, ZarrGroup):
390379
from zarr.core.sync import sync
391380

392381
sync(store.store.clear())
393382
else:
394383
store.clear()
395-
elif k in store:
396-
del store[k]
384+
else:
385+
# we allow stores to have a prefix like /uns which are then written to with keys like /uns/foo
386+
if k.startswith("/"):
387+
k = str(PurePosixPath(k).relative_to(store.name, walk_up=False))
388+
389+
# Apart from this code, we also ban keys containing slashes in `write_adata`/`write_h5ad`
390+
# for AnnData elements other than `obs`, `var`, and `uns`.
391+
if "/" in k:
392+
if (
393+
isinstance(store, ZarrGroup)
394+
or settings.disallow_forward_slash_in_h5ad
395+
):
396+
msg = f"Forward slashes are not allowed in keys in {type(store)}"
397+
raise ValueError(msg)
398+
msg = "Forward slashes will be written differently in a future anndata version"
399+
warn(msg, FutureWarning)
400+
401+
if k in store:
402+
del store[k]
397403

398404
# Normalize array-API (e.g., JAX/CuPy) even if not AnnData
399405
elem = normalize_nested(elem)
400406

401-
write_func = self.find_write_func(dest_type, elem, modifiers)
407+
write_func = self.find_write_func(type(store), elem, modifiers)
402408

403409
if self.callback is None:
404410
return write_func(store, k, elem, dataset_kwargs=dataset_kwargs)

tests/test_io_elementwise.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,7 @@ def test_categorical_order_type(store):
690690

691691

692692
def test_override_specification():
693-
"""
694-
Test that trying to overwrite an existing encoding raises an error.
695-
"""
693+
"""Test that trying to overwrite an existing encoding raises an error."""
696694
from copy import deepcopy
697695

698696
registry = deepcopy(_REGISTRY)

tests/test_readwrite.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import anndata as ad
2020
from anndata._io.specs.registry import IORegistryError
2121
from anndata._io.zarr import open_write_group
22+
from anndata._types import AnnDataElem
2223
from anndata.compat import (
2324
CSArray,
2425
CSMatrix,
@@ -35,6 +36,7 @@
3536
jnp,
3637
jnp_array_or_idempotent,
3738
)
39+
from anndata.utils import get_literal_members
3840

3941
if TYPE_CHECKING:
4042
from collections.abc import Callable, Generator
@@ -960,20 +962,20 @@ def test_h5py_attr_limit(tmp_path):
960962

961963

962964
@pytest.mark.parametrize(
963-
"elem_key", ["obs", "var", "obsm", "varm", "layers", "obsp", "varp", "uns"]
965+
"elem_key", set(get_literal_members(AnnDataElem)) - {"raw", "X"}
964966
)
965967
@pytest.mark.parametrize("store_type", ["zarr", "h5ad"])
966968
@pytest.mark.parametrize(
967969
"disallow_forward_slash_in_h5ad", [True, False], ids=["ban_slash", "allow_slash"]
968970
)
969971
def test_forward_slash_key(
970-
elem_key: Literal["obs", "var", "obsm", "varm", "layers", "obsp", "varp", "uns"],
972+
*,
971973
tmp_path: Path,
974+
elem_key: AnnDataElem,
972975
store_type: Literal["zarr", "h5ad"],
973-
*,
974976
disallow_forward_slash_in_h5ad: bool,
975-
):
976-
a = ad.AnnData(np.ones((10, 10)))
977+
) -> None:
978+
a = ad.AnnData(shape=(10, 10))
977979
getattr(a, elem_key)["bad/key"] = np.ones(
978980
(10,) if elem_key in ["obs", "var"] else (10, 10)
979981
)
@@ -1008,6 +1010,36 @@ def test_forward_slash_key(
10081010
assert not elem.keys()
10091011

10101012

1013+
@pytest.mark.parametrize(
1014+
"elem_key", set(get_literal_members(AnnDataElem)) - {"raw", "X"}
1015+
)
1016+
@pytest.mark.parametrize("store_type", ["zarr", "h5ad"])
1017+
@pytest.mark.parametrize("key", ["/", "/y"])
1018+
def test_leading_slash_error(
1019+
*,
1020+
tmp_path: Path,
1021+
elem_key: AnnDataElem,
1022+
store_type: Literal["zarr", "h5ad"],
1023+
key: str,
1024+
) -> None:
1025+
a = ad.AnnData(shape=(10, 10))
1026+
getattr(a, elem_key)[key] = np.ones(
1027+
(10,) if elem_key in ["obs", "var"] else (10, 10)
1028+
)
1029+
path = tmp_path / f"test.{store_type}"
1030+
1031+
# “not in the subpath” is raised by e.g. `write_elem(g["z"], "/y", ...)`,
1032+
# while “Forward slashes” is raised earlier by `write_anndata`/`write_h5ad`
1033+
with pytest.raises(ValueError, match=r"not in the subpath|Forward slashes"):
1034+
getattr(a, f"write_{store_type}")(path)
1035+
1036+
elem = getattr(getattr(ad, f"read_{store_type}")(path), elem_key)
1037+
if elem_key in {"obs", "var"}:
1038+
assert set(elem.columns) == {"_index"}
1039+
else:
1040+
assert not elem.keys()
1041+
1042+
10111043
@pytest.mark.skipif(
10121044
bool(find_spec("xarray")),
10131045
reason="Xarray is installed so `read_{elem_}lazy` will not error",

0 commit comments

Comments
 (0)