@@ -448,6 +448,113 @@ def test_write_indptr_dtype_override(store, sparse_format):
448448 np .testing .assert_array_equal (store ["X/indptr" ][...], X .indptr )
449449
450450
451+ @pytest .mark .parametrize (
452+ ("num_minor_axis" , "expected_dtype" ),
453+ [
454+ pytest .param (1 , np .dtype ("uint8" ), id = "one_col-expected_uint8_on_disk" ),
455+ pytest .param (
456+ np .iinfo (np .uint8 ).max ,
457+ np .dtype ("uint8" ),
458+ id = "max_np.uint8-matching_dtype_on_disk" ,
459+ ),
460+ pytest .param (
461+ np .iinfo (np .int8 ).max ,
462+ np .dtype ("uint8" ),
463+ id = "max_np.int8-uint8_on_disk" ,
464+ ),
465+ pytest .param (
466+ np .iinfo (np .uint16 ).max ,
467+ np .dtype ("uint16" ),
468+ id = "max_np.uint16-matching_dtype_on_disk" ,
469+ ),
470+ pytest .param (
471+ np .iinfo (np .int16 ).max ,
472+ np .dtype ("uint16" ),
473+ id = "max_np.int16-uint16_on_disk" ,
474+ ),
475+ pytest .param (
476+ np .iinfo (np .uint32 ).max ,
477+ np .dtype ("uint32" ),
478+ id = "max_np.uint32-matching_dtype_on_disk" ,
479+ ),
480+ pytest .param (
481+ np .iinfo (np .int32 ).max ,
482+ np .dtype ("uint32" ),
483+ id = "max_np.int32-uint32_on_disk" ,
484+ ),
485+ pytest .param (
486+ np .iinfo (np .uint8 ).max + 1 ,
487+ np .dtype ("uint16" ),
488+ id = "max_np.uint8_plus_one_cols-expected_uint16_on_disk" ,
489+ ),
490+ pytest .param (
491+ np .iinfo (np .uint16 ).max + 1 ,
492+ np .dtype ("uint32" ),
493+ id = "max_np.uint16_plus_one_cols-expected_uint32_on_disk" ,
494+ ),
495+ pytest .param (
496+ np .iinfo (np .uint32 ).max + 1 ,
497+ np .dtype ("uint64" ),
498+ id = "max_np.uint32_plus_one_cols-expected_uint64_on_disk" ,
499+ ),
500+ pytest .param (
501+ np .iinfo (np .int64 ).max + 1 ,
502+ np .dtype ("uint64" ),
503+ id = "max_np.int64_plus_one_cols-expected_uint64_on_disk" ,
504+ marks = pytest .mark .xfail (
505+ reason = "scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
506+ ),
507+ ),
508+ pytest .param (
509+ np .iinfo (np .uint64 ).max + 1 ,
510+ np .dtype ("uint64" ),
511+ id = "max_np.uint64_plus_one_cols-expected_uint64_on_disk" ,
512+ marks = pytest .mark .xfail (
513+ reason = "scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
514+ ),
515+ ),
516+ ],
517+ )
518+ @pytest .mark .parametrize ("format" , ["csr" , "csc" ])
519+ def test_write_indices_min (
520+ store : H5Group | ZarrGroup ,
521+ num_minor_axis : int ,
522+ expected_dtype : np .dtype ,
523+ format : Literal ["csr" , "csc" ],
524+ ):
525+ minor_axis_index = np .array ([num_minor_axis - 1 ])
526+ major_axis_index = np .array ([10 ])
527+ row_cols = (
528+ (minor_axis_index , major_axis_index )
529+ if format == "csc"
530+ else (major_axis_index , minor_axis_index )
531+ )
532+ shape = (num_minor_axis , 20 ) if format == "csc" else (20 , num_minor_axis )
533+ X = getattr (sparse , f"{ format } _array" )(
534+ (np .array ([10 ]), row_cols ),
535+ shape = shape ,
536+ )
537+ assert X .nnz == 1
538+ with ad .settings .override (write_csr_csc_indices_with_min_possible_dtype = True ):
539+ write_elem (store , "X" , X )
540+
541+ assert store ["X/indices" ].dtype == expected_dtype
542+ with ad .settings .override (use_sparse_array_on_read = True ):
543+ result = read_elem (store ["X" ])
544+ assert_equal (result .data , X .data )
545+ assert_equal (result .indices , X .indices )
546+ assert_equal (result .indptr , X .indptr )
547+ assert X .format == result .format
548+ assert result .shape == X .shape
549+ # != comparison converts to csr, which allocates a lot of memory or errors out with:
550+ # ValueError: array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.
551+ # Because the old, very large, minor axis is now the major axis and so either it fails to create or the indptr is very big.
552+ # The above tests should be enough to capture the desired equality checks so this is mostly for being extra sure.
553+ # See https://github.com/scipy/scipy/issues/23826
554+ if not (format == "csc" and num_minor_axis > np .iinfo (np .uint16 ).max + 1 ):
555+ assert (result != X ).nnz == 0
556+
557+
451558def test_io_spec_raw (store ):
452559 adata = gen_adata ((3 , 2 ), ** GEN_ADATA_NO_XARRAY_ARGS )
453560 adata .raw = adata .copy ()
0 commit comments