Skip to content

Commit c87dc0a

Browse files
fix: dont override dataset kwargs in loop for sparse (#2445)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a7ba562 commit c87dc0a

3 files changed

Lines changed: 24 additions & 2 deletions

File tree

docs/release-notes/2445.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make sure `indices`, `data`, and `indptr` have `zarr.config.set({'array.target_shard_size_bytes'})` applied instead of having one override the other's setting when writing sparse matrices. {user}`ilan-gold`

src/anndata/_io/specs/methods.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -804,9 +804,9 @@ def write_sparse_compressed(
804804
else:
805805
with zarr_v3_sharding(
806806
dataset_kwargs, format=f.metadata.zarr_format
807-
) as dataset_kwargs:
807+
) as dataset_kwargs_local:
808808
arr = g.create_array(
809-
attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs
809+
attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs_local
810810
)
811811
# see https://github.com/zarr-developers/zarr-python/discussions/2712
812812
arr[...] = attr[...]

tests/test_io_elementwise.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,27 @@ def test_write_auto_sharded_default_warns(tmp_path: Path):
948948
adata.write_zarr(path)
949949

950950

951+
@pytest.mark.zarr_io
952+
@pytest.mark.skipif(
953+
Version(version("zarr")) < Version("3.1.4"),
954+
reason="autosharding with chosen size was not available",
955+
)
956+
def test_write_auto_sharded_size_sparse(tmp_path: Path):
957+
path = "memory://check_shards.zarr"
958+
z = zarr.open(path)
959+
mat = sparse.random(
960+
1000, 1000, density=0.5, format="csr", random_state=np.random.default_rng(42)
961+
)
962+
ad.io.write_elem(z, "two_shards_per_sub_element", mat)
963+
# i.e., there are at most two shards since one shard will contain two chunks,
964+
# and the other the last elements, since the target size is 1GB uncompressed.
965+
for sub_element in ["indices", "data", "indptr"]:
966+
assert (
967+
z["two_shards_per_sub_element"][sub_element].shape[0]
968+
/ z["two_shards_per_sub_element"][sub_element].shards[0]
969+
) < 2, sub_element
970+
971+
951972
@pytest.mark.zarr_io
952973
def test_write_auto_sharded_does_not_override(tmp_path: Path):
953974
z = open_write_group(tmp_path / "arr.zarr", zarr_format=3)

0 commit comments

Comments
 (0)