@@ -455,6 +455,113 @@ def test_write_indptr_dtype_override(store, sparse_format):
455455 np .testing .assert_array_equal (store ["X/indptr" ][...], X .indptr )
456456
457457
@pytest.mark.parametrize(
    ("num_minor_axis", "expected_dtype"),
    [
        pytest.param(1, np.dtype("uint8"), id="one_col-expected_uint8_on_disk"),
        pytest.param(
            np.iinfo(np.uint8).max,
            np.dtype("uint8"),
            id="max_np.uint8-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int8).max,
            np.dtype("uint8"),
            id="max_np.int8-uint8_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint16).max,
            np.dtype("uint16"),
            id="max_np.uint16-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int16).max,
            np.dtype("uint16"),
            id="max_np.int16-uint16_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint32).max,
            np.dtype("uint32"),
            id="max_np.uint32-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int32).max,
            np.dtype("uint32"),
            id="max_np.int32-uint32_on_disk",
        ),
        # One past each unsigned maximum: the next wider dtype is expected.
        pytest.param(
            np.iinfo(np.uint8).max + 1,
            np.dtype("uint16"),
            id="max_np.uint8_plus_one_cols-expected_uint16_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint16).max + 1,
            np.dtype("uint32"),
            id="max_np.uint16_plus_one_cols-expected_uint32_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint32).max + 1,
            np.dtype("uint64"),
            id="max_np.uint32_plus_one_cols-expected_uint64_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int64).max + 1,
            np.dtype("uint64"),
            id="max_np.int64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
        pytest.param(
            np.iinfo(np.uint64).max + 1,
            np.dtype("uint64"),
            id="max_np.uint64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
    ],
)
@pytest.mark.parametrize("format", ["csr", "csc"])
def test_write_indices_min(
    store: H5Group | ZarrGroup,
    num_minor_axis: int,
    expected_dtype: np.dtype,
    format: Literal["csr", "csc"],
):
    """Check the on-disk dtype of ``indices`` when the min-dtype setting is on.

    A sparse array with a single stored value is built so that its minor
    axis has length ``num_minor_axis`` and the one non-zero sits at the last
    minor-axis position (``num_minor_axis - 1``). The array is written with
    ``write_csr_csc_indices_with_min_possible_dtype=True``; the test then
    asserts that ``store["X/indices"]`` has ``expected_dtype`` and that
    reading the element back reproduces ``data``, ``indices``, ``indptr``,
    ``format`` and ``shape``.

    Parameters
    ----------
    store
        Group the element is written to and read back from.
    num_minor_axis
        Length of the minor axis (columns for ``"csr"``, rows for ``"csc"``).
    expected_dtype
        Dtype the ``indices`` array is expected to have in the store.
    format
        Sparse layout under test.
    """
    # The single non-zero is placed at the largest valid minor-axis position.
    minor_axis_index = np.array([num_minor_axis - 1])
    major_axis_index = np.array([10])

    # The (row, col) coordinate order and the shape depend on which axis is
    # the minor one for the layout under test.
    if format == "csc":
        row_cols = (minor_axis_index, major_axis_index)
        shape = (num_minor_axis, 20)
    else:
        row_cols = (major_axis_index, minor_axis_index)
        shape = (20, num_minor_axis)

    # Resolves to ``sparse.csr_array`` / ``sparse.csc_array``.
    sparse_cls = getattr(sparse, f"{format}_array")
    X = sparse_cls((np.array([10]), row_cols), shape=shape)
    assert X.nnz == 1

    with ad.settings.override(write_csr_csc_indices_with_min_possible_dtype=True):
        write_elem(store, "X", X)

    assert store["X/indices"].dtype == expected_dtype

    with ad.settings.override(use_sparse_array_on_read=True):
        result = read_elem(store["X"])

    assert_equal(result.data, X.data)
    assert_equal(result.indices, X.indices)
    assert_equal(result.indptr, X.indptr)
    assert X.format == result.format
    assert result.shape == X.shape

    # The `!=` comparison converts to csr, which either allocates a lot of memory
    # or errors out with:
    # ValueError: array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.
    # The old, very large, minor axis becomes the major axis, so either the
    # conversion fails outright or the resulting indptr is very big.
    # The checks above should already capture the desired equality, so this
    # one is only an extra safeguard and is skipped for the large csc cases.
    # See https://github.com/scipy/scipy/issues/23826
    if not (format == "csc" and num_minor_axis > np.iinfo(np.uint16).max + 1):
        assert (result != X).nnz == 0
563+
564+
458565def test_io_spec_raw (store ):
459566 adata = gen_adata ((3 , 2 ), ** GEN_ADATA_NO_XARRAY_ARGS )
460567 adata .raw = adata .copy ()
0 commit comments