Skip to content

Commit 7ed86f7

Browse files
perf: add ability to write downcasted indices (#2159)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 2a2c0e3 commit 7ed86f7

5 files changed

Lines changed: 135 additions & 0 deletions

File tree

docs/release-notes/2159.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a `write_csr_csc_indices_with_min_possible_dtype` option to {attr}`anndata.settings` that, when enabled, downcasts the `indices` of CSR and CSC matrices to the smallest unsigned integer dtype that fits when writing. For example, if your CSR matrix has only 30,000 columns, its `indices` can be written as `uint16` instead of `int64`. {user}`ilan-gold`

src/anndata/_io/specs/methods.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,24 @@ def write_sparse_compressed(
731731
for attr_name in ["data", "indices", "indptr"]:
732732
attr = getattr(value, attr_name)
733733
dtype = indptr_dtype if attr_name == "indptr" else attr.dtype
734+
if (
735+
attr_name == "indices"
736+
and settings.write_csr_csc_indices_with_min_possible_dtype
737+
):
738+
# np.min_scalar_type can return things like np.ulonglong which zarr doesn't understand
739+
# and I find this clearer as to what the result type is i.e., unsigned or signed.
740+
# For example `np.iinfo(np.uint16).max + 1` could be either `uint32` or `int32`,
741+
# and there's nothing in numpy's docs disallowing this output to change.
742+
if (minor_axis_size := value.shape[value.format == "csr"]) <= np.iinfo(
743+
np.uint8
744+
).max:
745+
dtype = np.dtype("uint8")
746+
elif minor_axis_size <= np.iinfo(np.uint16).max:
747+
dtype = np.dtype("uint16")
748+
elif minor_axis_size <= np.iinfo(np.uint32).max:
749+
dtype = np.dtype("uint32")
750+
elif minor_axis_size <= np.iinfo(np.uint64).max:
751+
dtype = np.dtype("uint64")
734752
if isinstance(f, H5Group) or is_zarr_v2():
735753
g.create_dataset(
736754
attr_name, data=attr, shape=attr.shape, dtype=dtype, **dataset_kwargs

src/anndata/_settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,14 @@ def validate_sparse_settings(val: Any, settings: SettingsManager) -> None:
516516
get_from_env=check_and_get_bool,
517517
)
518518

519+
settings.register(
520+
"write_csr_csc_indices_with_min_possible_dtype",
521+
default_value=False,
522+
description="Write a csr or csc matrix with the minimum possible data type for `indices`, always unsigned integer.",
523+
validate=validate_bool,
524+
get_from_env=check_and_get_bool,
525+
)
526+
519527
settings.register(
520528
"auto_shard_zarr_v3",
521529
default_value=False,

src/anndata/_settings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class _AnnDataSettingsManager(SettingsManager):
4444
use_sparse_array_on_read: bool = False
4545
min_rows_for_chunked_h5_copy: int = 1000
4646
disallow_forward_slash_in_h5ad: bool = False
47+
write_csr_csc_indices_with_min_possible_dtype: bool = False
4748
auto_shard_zarr_v3: bool = False
4849

4950
settings: _AnnDataSettingsManager

tests/test_io_elementwise.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,113 @@ def test_write_indptr_dtype_override(store, sparse_format):
448448
np.testing.assert_array_equal(store["X/indptr"][...], X.indptr)
449449

450450

451+
@pytest.mark.parametrize(
    ("num_minor_axis", "expected_dtype"),
    [
        # Sizes that sit at or below a dtype boundary.
        pytest.param(1, np.dtype("uint8"), id="one_col-expected_uint8_on_disk"),
        pytest.param(
            np.iinfo(np.uint8).max,
            np.dtype("uint8"),
            id="max_np.uint8-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int8).max, np.dtype("uint8"), id="max_np.int8-uint8_on_disk"
        ),
        pytest.param(
            np.iinfo(np.uint16).max,
            np.dtype("uint16"),
            id="max_np.uint16-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int16).max, np.dtype("uint16"), id="max_np.int16-uint16_on_disk"
        ),
        pytest.param(
            np.iinfo(np.uint32).max,
            np.dtype("uint32"),
            id="max_np.uint32-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int32).max, np.dtype("uint32"), id="max_np.int32-uint32_on_disk"
        ),
        # Sizes one past an unsigned boundary, which must spill into the next dtype.
        pytest.param(
            np.iinfo(np.uint8).max + 1,
            np.dtype("uint16"),
            id="max_np.uint8_plus_one_cols-expected_uint16_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint16).max + 1,
            np.dtype("uint32"),
            id="max_np.uint16_plus_one_cols-expected_uint32_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint32).max + 1,
            np.dtype("uint64"),
            id="max_np.uint32_plus_one_cols-expected_uint64_on_disk",
        ),
        # Sizes beyond what int64 can hold; these cases are expected to fail.
        pytest.param(
            np.iinfo(np.int64).max + 1,
            np.dtype("uint64"),
            id="max_np.int64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
        pytest.param(
            np.iinfo(np.uint64).max + 1,
            np.dtype("uint64"),
            id="max_np.uint64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
    ],
)
@pytest.mark.parametrize("format", ["csr", "csc"])
def test_write_indices_min(
    store: H5Group | ZarrGroup,
    num_minor_axis: int,
    expected_dtype: np.dtype,
    format: Literal["csr", "csc"],
):
    """Check the on-disk dtype of ``indices`` when min-dtype writing is enabled.

    A sparse array with a single stored value is built so that its minor axis
    has length ``num_minor_axis``. It is written with
    ``write_csr_csc_indices_with_min_possible_dtype`` switched on, the dtype of
    ``X/indices`` in the store is compared with ``expected_dtype``, and the
    element is read back and compared with the array that was written.
    """
    # The single entry sits in the last position of the minor axis and at
    # position 10 of the 20-long major axis.
    last_minor_position = np.array([num_minor_axis - 1])
    major_position = np.array([10])
    if format == "csc":
        coords = (last_minor_position, major_position)
        shape = (num_minor_axis, 20)
    else:
        coords = (major_position, last_minor_position)
        shape = (20, num_minor_axis)

    build_sparse = getattr(sparse, f"{format}_array")
    X = build_sparse((np.array([10]), coords), shape=shape)
    assert X.nnz == 1

    with ad.settings.override(write_csr_csc_indices_with_min_possible_dtype=True):
        write_elem(store, "X", X)
    assert store["X/indices"].dtype == expected_dtype

    with ad.settings.override(use_sparse_array_on_read=True):
        result = read_elem(store["X"])
    for component in ("data", "indices", "indptr"):
        assert_equal(getattr(result, component), getattr(X, component))
    assert X.format == result.format
    assert result.shape == X.shape

    # The `!=` comparison below converts to csr. For a csc input with a very
    # large minor axis, that axis becomes the major axis, so the conversion
    # either allocates a huge indptr or fails with:
    # ValueError: array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.
    # The component-wise checks above already cover equality, so the extra
    # whole-matrix comparison is simply skipped for those cases.
    # See https://github.com/scipy/scipy/issues/23826
    if format == "csc" and num_minor_axis > np.iinfo(np.uint16).max + 1:
        return
    assert (result != X).nnz == 0
556+
557+
451558
def test_io_spec_raw(store):
452559
adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
453560
adata.raw = adata.copy()

0 commit comments

Comments
 (0)