Skip to content

Commit 5ef1dea

Browse files
authored
Backport PR #2433 on branch 0.12.x (fix: correct IO types) (#2434)
1 parent 5d5858d commit 5ef1dea

6 files changed

Lines changed: 47 additions & 30 deletions

File tree

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@
2424
//"-nauto",
2525
],
2626
"python.terminal.activateEnvironment": true,
27+
"python.analysis.include": ["src/**/*", "ci/scripts/**/*", "tests/**/*"],
28+
"python-envs.defaultEnvManager": "pypa.hatch:hatch",
29+
"python-envs.defaultPackageManager": "pypa.hatch:hatch",
2730
}

src/anndata/_io/specs/lazy_methods.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,7 @@ def resolve_chunks(
202202
# In the long run, it might be good to figure out what exactly is going on here but for now, this will do.
203203
@_LAZY_REGISTRY.register_read(H5Array, IOSpec("string-array", "0.2.0"))
204204
def read_h5_string_array(
205-
elem: H5Array,
206-
*,
207-
_reader: LazyReader,
208-
chunks: tuple[int] | None = None,
205+
elem: H5Array, *, _reader: LazyReader, chunks: tuple[int, ...] | None = None
209206
) -> DaskArray:
210207
import dask.array as da
211208

@@ -249,7 +246,7 @@ def read_zarr_array(
249246
def _gen_xarray_dict_iterator_from_elems(
250247
elem_dict: dict[str, LazyDataStructures],
251248
dim_name: str,
252-
index: np.NDArray,
249+
index: np.typing.NDArray,
253250
) -> Generator[tuple[str, XVariable], None, None]:
254251
from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray
255252

@@ -291,7 +288,7 @@ def read_dataframe(
291288
*,
292289
_reader: LazyReader,
293290
use_range_index: bool = False,
294-
chunks: tuple[int] | None = None,
291+
chunks: tuple[int, ...] | None = None,
295292
) -> Dataset2D:
296293
# going through dask for reading into memory the index doesn't make sense, hence the ternary.
297294
elem_dict = {
@@ -337,9 +334,12 @@ def read_categorical(
337334
elem: H5Group | ZarrGroup,
338335
*,
339336
_reader: LazyReader,
337+
chunks: tuple[int, ...] | None = None,
340338
) -> CategoricalArray:
341339
from anndata.experimental.backed._lazy_arrays import CategoricalArray
342340

341+
del chunks # ignored when reading groups
342+
343343
base_path_or_zarr_group = (
344344
Path(filename(elem)) if isinstance(elem, H5Group) else elem
345345
)
@@ -361,9 +361,12 @@ def read_nullable(
361361
"nullable-integer", "nullable-boolean", "nullable-string-array"
362362
],
363363
_reader: LazyReader,
364+
chunks: tuple[int, ...] | None = None,
364365
) -> MaskedArray:
365366
from anndata.experimental.backed._lazy_arrays import MaskedArray
366367

368+
del chunks # ignored when reading groups
369+
367370
base_path_or_zarr_group = (
368371
Path(filename(elem)) if isinstance(elem, H5Group) else elem
369372
)

src/anndata/_io/specs/methods.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from importlib.metadata import version
88
from itertools import product
99
from types import MappingProxyType
10-
from typing import TYPE_CHECKING
10+
from typing import TYPE_CHECKING, Protocol
1111
from warnings import warn
1212

1313
import h5py
@@ -46,7 +46,7 @@
4646
from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial
4747

4848
if TYPE_CHECKING:
49-
from collections.abc import Callable, Iterator
49+
from collections.abc import Iterator
5050
from os import PathLike
5151
from typing import Any, Literal
5252

@@ -443,7 +443,7 @@ def write_basic(
443443
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
444444
):
445445
"""Write methods which underlying library handles natively."""
446-
dataset_kwargs = dataset_kwargs.copy()
446+
dataset_kwargs = dict(dataset_kwargs)
447447
dtype = dataset_kwargs.pop("dtype", elem.dtype)
448448
if isinstance(f, H5Group) or is_zarr_v2():
449449
f.create_dataset(k, data=elem, shape=elem.shape, dtype=dtype, **dataset_kwargs)
@@ -525,7 +525,7 @@ def write_basic_dask_dask_dense(
525525
):
526526
import dask.array as da
527527

528-
dataset_kwargs = dataset_kwargs.copy()
528+
dataset_kwargs = dict(dataset_kwargs)
529529
is_h5 = isinstance(f, H5Group)
530530
if not is_h5:
531531
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
@@ -629,7 +629,7 @@ def write_vlen_string_array_zarr(
629629
from numcodecs import VLenUTF8
630630
from zarr.core.dtype import VariableLengthUTF8
631631

632-
dataset_kwargs = dataset_kwargs.copy()
632+
dataset_kwargs = dict(dataset_kwargs)
633633
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
634634
dtype = VariableLengthUTF8()
635635
filters, fill_value = None, None
@@ -702,7 +702,7 @@ def write_recarray_zarr(
702702
if is_zarr_v2():
703703
f.create_dataset(k, data=elem, shape=elem.shape, **dataset_kwargs)
704704
else:
705-
dataset_kwargs = dataset_kwargs.copy()
705+
dataset_kwargs = dict(dataset_kwargs)
706706
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
707707
# https://github.com/zarr-developers/zarr-python/issues/3546
708708
# if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
@@ -1210,14 +1210,14 @@ def write_nullable(
12101210
)(write_nullable)
12111211

12121212

1213+
class _BaseMaskedArray(Protocol):
1214+
def __call__(
1215+
self, values: NDArray[np.number], /, *, mask: NDArray[np.bool_]
1216+
) -> pd.api.extensions.ExtensionArray: ...
1217+
1218+
12131219
def _read_nullable(
1214-
elem: GroupStorageType,
1215-
*,
1216-
_reader: Reader,
1217-
# BaseMaskedArray
1218-
array_type: Callable[
1219-
[NDArray[np.number], NDArray[np.bool_]], pd.api.extensions.ExtensionArray
1220-
],
1220+
elem: GroupStorageType, *, _reader: Reader, array_type: _BaseMaskedArray
12211221
) -> pd.api.extensions.ExtensionArray:
12221222
return array_type(
12231223
_reader.read_elem(elem["values"]),
@@ -1378,7 +1378,7 @@ def write_string(
13781378
_writer: Writer,
13791379
dataset_kwargs: Mapping[str, Any],
13801380
):
1381-
dataset_kwargs = dataset_kwargs.copy()
1381+
dataset_kwargs = dict(dataset_kwargs)
13821382
dataset_kwargs.pop("compression", None)
13831383
dataset_kwargs.pop("compression_opts", None)
13841384
f.create_dataset(

src/anndata/_io/specs/registry.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from ..._core.xarray import Dataset2D
3232

33+
S = TypeVar("S")
3334
T = TypeVar("T")
3435
W = TypeVar("W", bound=_WriteInternal)
3536
LazyDataStructures = DaskArray | Dataset2D | CategoricalArray | MaskedArray
@@ -103,7 +104,7 @@ def register_write(
103104
src_type: type | tuple[type, str],
104105
spec: IOSpec | Mapping[str, str],
105106
modifiers: Iterable[str] = frozenset(),
106-
) -> Callable[[_WriteInternal[T]], _WriteInternal[T]]:
107+
) -> Callable[[_WriteInternal[S, T]], _WriteInternal[S, T]]:
107108
spec = proc_spec(spec)
108109
modifiers = frozenset(modifiers)
109110

@@ -119,7 +120,7 @@ def register_write(
119120
else:
120121
self.write_specs[src_type] = spec
121122

122-
def _register(func):
123+
def _register(func: _WriteInternal[S, T]) -> _WriteInternal[S, T]:
123124
self.write[(dest_type, src_type, modifiers)] = write_spec(spec)(func)
124125
return func
125126

@@ -351,7 +352,7 @@ def write_elem(
351352

352353
# we allow stores to have a prefix like /uns which are then written to with keys like /uns/foo
353354
is_zarr_group = isinstance(store, ZarrGroup)
354-
if "/" in k.split(store.name)[-1][1:]:
355+
if "/" in k.rsplit(store.name, maxsplit=1)[-1][1:]:
355356
if is_zarr_group or settings.disallow_forward_slash_in_h5ad:
356357
msg = f"Forward slashes are not allowed in keys in {type(store)}"
357358
raise ValueError(msg)

src/anndata/_io/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _read_legacy_raw(
312312
return raw
313313

314314

315-
def zero_dim_array_as_scalar(func: _WriteInternal):
315+
def zero_dim_array_as_scalar(func: _WriteInternal) -> _WriteInternal:
316316
"""\
317317
A decorator for write_elem implementations of arrays where zero-dimensional arrays need special handling.
318318
"""

src/anndata/_types.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,16 @@
2525

2626

2727
__all__ = [
28+
"AnnDataElem",
2829
"ArrayStorageType",
30+
"Dataset2DIlocIndexer",
2931
"GroupStorageType",
32+
"Read",
33+
"ReadCallback",
34+
"ReadLazy",
3035
"StorageType",
36+
"Write",
37+
"WriteCallback",
3138
"_ReadInternal",
3239
"_ReadLazyInternal",
3340
"_WriteInternal",
@@ -47,25 +54,26 @@
4754

4855

4956
class Dataset2DIlocIndexer(Protocol):
50-
def __getitem__(self, idx: Any) -> Dataset2D: ...
57+
def __getitem__(self, idx: Any, /) -> Dataset2D: ...
5158

5259

5360
class _ReadInternal(Protocol[S_contra, RWAble_co]):
54-
def __call__(self, elem: S_contra, *, _reader: Reader) -> RWAble_co: ...
61+
def __call__(self, elem: S_contra, /, *, _reader: Reader) -> RWAble_co: ...
5562

5663

5764
class _ReadLazyInternal(Protocol[S_contra]):
5865
def __call__(
5966
self,
6067
elem: S_contra,
68+
/,
6169
*,
6270
_reader: LazyReader,
6371
chunks: tuple[int, ...] | None = None,
6472
) -> LazyDataStructures: ...
6573

6674

6775
class Read(Protocol[S_contra, RWAble_co]):
68-
def __call__(self, elem: S_contra) -> RWAble_co:
76+
def __call__(self, elem: S_contra, /) -> RWAble_co:
6977
"""Low-level reading function for an element.
7078
7179
Parameters
@@ -81,7 +89,7 @@ def __call__(self, elem: S_contra) -> RWAble_co:
8189

8290
class ReadLazy(Protocol[S_contra]):
8391
def __call__(
84-
self, elem: S_contra, *, chunks: tuple[int, ...] | None = None
92+
self, elem: S_contra, /, *, chunks: tuple[int, ...] | None = None
8593
) -> LazyDataStructures:
8694
"""Low-level reading function for a lazy element.
8795
@@ -98,12 +106,13 @@ def __call__(
98106
...
99107

100108

101-
class _WriteInternal(Protocol[RWAble_contra]):
109+
class _WriteInternal(Protocol[S_contra, RWAble_contra]):
102110
def __call__(
103111
self,
104-
f: StorageType,
112+
f: S_contra,
105113
k: str,
106114
v: RWAble_contra,
115+
/,
107116
*,
108117
_writer: Writer,
109118
dataset_kwargs: Mapping[str, Any],
@@ -116,6 +125,7 @@ def __call__(
116125
f: StorageType,
117126
k: str,
118127
v: RWAble_contra,
128+
/,
119129
*,
120130
dataset_kwargs: Mapping[str, Any],
121131
) -> None:

0 commit comments

Comments
 (0)