Skip to content

Commit 1af48ee

Browse files
ilan-gold authored and meeseeksmachine committed
Backport PR scverse#2159: perf: add ability to write downcasted indices
1 parent 172c4db commit 1af48ee

5 files changed

Lines changed: 135 additions & 0 deletions

File tree

docs/release-notes/2159.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a `write_csr_csc_indices_with_min_possible_dtype` option to {attr}`anndata.settings` to enable downcasting of the `indices` of csr and csc matrices to a smaller dtype when writing. For example, if your csr matrix only has 30000 columns, then you can write out the `indices` of that matrix as `uint16` instead of `int64`. {user}`ilan-gold`

src/anndata/_io/specs/methods.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,24 @@ def write_sparse_compressed(
738738
for attr_name in ["data", "indices", "indptr"]:
739739
attr = getattr(value, attr_name)
740740
dtype = indptr_dtype if attr_name == "indptr" else attr.dtype
741+
if (
742+
attr_name == "indices"
743+
and settings.write_csr_csc_indices_with_min_possible_dtype
744+
):
745+
# np.min_scalar_type can return things like np.ulonglong which zarr doesn't understand
746+
# and I find this clearer as to what the result type is i.e., unsigned or signed.
747+
# For example `np.iinfo(np.uint16).max + 1` could be either `uint32` or `int32`,
748+
# and there's nothing in numpy's docs disallowing this output to change.
749+
if (minor_axis_size := value.shape[value.format == "csr"]) <= np.iinfo(
750+
np.uint8
751+
).max:
752+
dtype = np.dtype("uint8")
753+
elif minor_axis_size <= np.iinfo(np.uint16).max:
754+
dtype = np.dtype("uint16")
755+
elif minor_axis_size <= np.iinfo(np.uint32).max:
756+
dtype = np.dtype("uint32")
757+
elif minor_axis_size <= np.iinfo(np.uint64).max:
758+
dtype = np.dtype("uint64")
741759
if isinstance(f, H5Group) or is_zarr_v2():
742760
g.create_dataset(
743761
attr_name, data=attr, shape=attr.shape, dtype=dtype, **dataset_kwargs

src/anndata/_settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,14 @@ def validate_sparse_settings(val: Any, settings: SettingsManager) -> None:
503503
get_from_env=check_and_get_bool,
504504
)
505505

506+
# Opt-in setting (disabled by default): when enabled, the `indices` array of a
# csr/csc matrix is written with the smallest unsigned integer dtype that fits.
settings.register(
    "write_csr_csc_indices_with_min_possible_dtype",
    validate=validate_bool,
    get_from_env=check_and_get_bool,
    default_value=False,
    description="Write a csr or csc matrix with the minimum possible data type for `indices`, always unsigned integer.",
)
513+
506514
settings.register(
507515
"auto_shard_zarr_v3",
508516
default_value=False,

src/anndata/_settings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class _AnnDataSettingsManager(SettingsManager):
4646
use_sparse_array_on_read: bool = False
4747
min_rows_for_chunked_h5_copy: int = 1000
4848
disallow_forward_slash_in_h5ad: bool = False
49+
write_csr_csc_indices_with_min_possible_dtype: bool = False
4950
auto_shard_zarr_v3: bool = False
5051

5152
settings: _AnnDataSettingsManager

tests/test_io_elementwise.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,113 @@ def test_write_indptr_dtype_override(store, sparse_format):
455455
np.testing.assert_array_equal(store["X/indptr"][...], X.indptr)
456456

457457

458+
@pytest.mark.parametrize(
    ("num_minor_axis", "expected_dtype"),
    [
        pytest.param(1, np.dtype("uint8"), id="one_col-expected_uint8_on_disk"),
        # Sizes that sit exactly on an (un)signed integer maximum.
        pytest.param(
            np.iinfo(np.uint8).max,
            np.dtype("uint8"),
            id="max_np.uint8-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int8).max,
            np.dtype("uint8"),
            id="max_np.int8-uint8_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint16).max,
            np.dtype("uint16"),
            id="max_np.uint16-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int16).max,
            np.dtype("uint16"),
            id="max_np.int16-uint16_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint32).max,
            np.dtype("uint32"),
            id="max_np.uint32-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int32).max,
            np.dtype("uint32"),
            id="max_np.int32-uint32_on_disk",
        ),
        # Sizes one past an unsigned maximum: the next wider dtype is expected.
        pytest.param(
            np.iinfo(np.uint8).max + 1,
            np.dtype("uint16"),
            id="max_np.uint8_plus_one_cols-expected_uint16_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint16).max + 1,
            np.dtype("uint32"),
            id="max_np.uint16_plus_one_cols-expected_uint32_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint32).max + 1,
            np.dtype("uint64"),
            id="max_np.uint32_plus_one_cols-expected_uint64_on_disk",
        ),
        # Sizes beyond what scipy sparse can index: expected to fail.
        pytest.param(
            np.iinfo(np.int64).max + 1,
            np.dtype("uint64"),
            id="max_np.int64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
        pytest.param(
            np.iinfo(np.uint64).max + 1,
            np.dtype("uint64"),
            id="max_np.uint64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
    ],
)
@pytest.mark.parametrize("format", ["csr", "csc"])
def test_write_indices_min(
    store: H5Group | ZarrGroup,
    num_minor_axis: int,
    expected_dtype: np.dtype,
    format: Literal["csr", "csc"],
):
    """Check the on-disk dtype of `indices` when min-dtype writing is enabled.

    A single-element csr/csc array whose minor axis has ``num_minor_axis``
    entries is written with ``write_csr_csc_indices_with_min_possible_dtype``
    switched on. The stored ``indices`` dataset must have ``expected_dtype``,
    and reading the element back must reproduce the original array.
    """
    # The only stored element sits at the last position of the minor axis.
    minor_idx = np.array([num_minor_axis - 1])
    major_idx = np.array([10])
    if format == "csc":
        # csc: the minor axis is the row axis.
        coords = (minor_idx, major_idx)
        shape = (num_minor_axis, 20)
    else:
        # csr: the minor axis is the column axis.
        coords = (major_idx, minor_idx)
        shape = (20, num_minor_axis)
    sparse_cls = getattr(sparse, f"{format}_array")
    X = sparse_cls((np.array([10]), coords), shape=shape)
    assert X.nnz == 1

    with ad.settings.override(write_csr_csc_indices_with_min_possible_dtype=True):
        write_elem(store, "X", X)

    assert store["X/indices"].dtype == expected_dtype

    with ad.settings.override(use_sparse_array_on_read=True):
        result = read_elem(store["X"])
    for attr_name in ("data", "indices", "indptr"):
        assert_equal(getattr(result, attr_name), getattr(X, attr_name))
    assert X.format == result.format
    assert result.shape == X.shape

    # The `!=` comparison converts to csr, which allocates a lot of memory or errors out with:
    # ValueError: array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.
    # The old, very large, minor axis becomes the major axis, so either creation fails or the indptr is very big.
    # The component-wise checks above already cover equality; the comparison below is only an extra safeguard.
    # See https://github.com/scipy/scipy/issues/23826
    if format == "csc" and num_minor_axis > np.iinfo(np.uint16).max + 1:
        return
    assert (result != X).nnz == 0
564+
458565
def test_io_spec_raw(store):
459566
adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
460567
adata.raw = adata.copy()

0 commit comments

Comments
 (0)