66from collections import Counter , defaultdict
77from collections .abc import Mapping
88from functools import partial , singledispatch , wraps
9+ from importlib .metadata import version
910from importlib .util import find_spec
1011from string import ascii_letters
1112from typing import TYPE_CHECKING
1415import numpy as np
1516import pandas as pd
1617import pytest
18+ from packaging .version import Version
1719from pandas .api .types import is_numeric_dtype
1820from scipy import sparse
1921
6163 * (pd .UInt8Dtype , pd .UInt16Dtype , pd .UInt32Dtype , pd .UInt64Dtype ),
6264 )
6365
66+ try :
67+ import fast_array_utils as _
68+ except ImportError :
69+ # dask natively supports sparray since https://github.com/dask/dask/pull/11750
70+ DASK_CAN_SPARRAY = Version (version ("dask" )) >= Version ("2025.3.0" )
71+ else : # fast-array-utils monkeypatches dask to support sparrays
72+ DASK_CAN_SPARRAY = True
73+
6474
6575DEFAULT_KEY_TYPES = (
6676 sparse .csr_matrix ,
@@ -628,8 +638,9 @@ def assert_equal_arrayview(
628638
629639@assert_equal .register (BaseCompressedSparseDataset )
630640@assert_equal .register (sparse .spmatrix )
641+ @assert_equal .register (CSArray )
631642def assert_equal_sparse (
632- a : BaseCompressedSparseDataset | sparse .spmatrix ,
643+ a : BaseCompressedSparseDataset | sparse .spmatrix | CSArray ,
633644 b : object ,
634645 * ,
635646 exact : bool = False ,
@@ -639,13 +650,6 @@ def assert_equal_sparse(
639650 assert_equal (b , a , exact = exact , elem_name = elem_name )
640651
641652
642- @assert_equal .register (CSArray )
643- def assert_equal_sparse_array (
644- a : CSArray , b : object , * , exact : bool = False , elem_name : str | None = None
645- ):
646- return assert_equal_sparse (a , b , exact = exact , elem_name = elem_name )
647-
648-
649653@assert_equal .register (CupySparseMatrix )
650654def assert_equal_cupy_sparse (
651655 a : CupySparseMatrix , b : object , * , exact : bool = False , elem_name : str | None = None
@@ -878,29 +882,37 @@ def _(a):
878882
879883
880884@singledispatch
881- def as_sparse_dask_array (a ) -> DaskArray :
882- import dask .array as da
883-
884- return da .from_array (sparse .csr_matrix (a ), chunks = _half_chunk_size (a .shape ))
885+ def _as_sparse_dask (
886+ a : NDArray | CSArray | CSMatrix | DaskArray , * , typ : type [CSArray | CSMatrix ]
887+ ) -> DaskArray :
888+ """Convert a to a sparse dask array, preserving sparse format and container (`cs{rc}_{array,matrix}`)."""
889+ raise NotImplementedError
885890
886891
@_as_sparse_dask.register(CSArray | CSMatrix | np.ndarray)
def _(a: CSArray | CSMatrix | NDArray, *, typ: type[CSArray | CSMatrix]) -> DaskArray:
    """Wrap an in-memory array in a dask array after converting it to `typ`.

    The input is first turned into a `typ` container by `_as_sparse_dask_inner`;
    the chunk layout comes from `_half_chunk_size` applied to the input's shape.
    """
    from dask.array import from_array

    converted = _as_sparse_dask_inner(a, typ=typ)
    return from_array(converted, chunks=_half_chunk_size(a.shape))
892897
893898
@_as_sparse_dask.register(DaskArray)
def _(a: DaskArray, *, typ: type[CSArray | CSMatrix]) -> DaskArray:
    """Convert every block of an existing dask array to a `typ` container.

    The conversion is applied block-wise through `map_blocks`, keeping the
    input's dtype and passing a small 2x2 `typ` instance as `meta`.
    """
    convert_block = partial(_as_sparse_dask_inner, typ=typ)
    return a.map_blocks(convert_block, dtype=a.dtype, meta=typ((2, 2)))
897902
898- return da .from_array (sparse .csr_matrix (a ), _half_chunk_size (a .shape ))
899903
def _as_sparse_dask_inner(
    a: NDArray | CSArray | CSMatrix, *, typ: type[CSArray | CSMatrix]
) -> CSArray | CSMatrix:
    """Convert `a` into a sparse container of type `typ` that dask supports (or complain).

    Raises
    ------
    TypeError
        If `typ` is a sparse *array* class while `DASK_CAN_SPARRAY` is false.
    """
    wants_sparray = issubclass(typ, CSArray)
    if wants_sparray and not DASK_CAN_SPARRAY:
        # No silent fallback to a sparse matrix here: the caller asked for a sparray.
        raise TypeError(
            "Dask <2025.3 without fast-array-utils doesn’t support sparse arrays"
        )
    return typ(a)
900912
901- @ as_sparse_dask_array . register ( DaskArray )
902- def _ ( a ):
903- return a . map_blocks ( sparse .csr_matrix )
913+
# Public converters: `_as_sparse_dask` with the target container fixed to
# `scipy.sparse.csr_array` and `scipy.sparse.csr_matrix` respectively.
as_sparse_dask_array = partial(_as_sparse_dask, typ=sparse.csr_array)
as_sparse_dask_matrix = partial(_as_sparse_dask, typ=sparse.csr_matrix)
904916
905917
906918@singledispatch
@@ -949,11 +961,8 @@ def _(a):
949961# We should try and fix this upstream in dask/ cupy
@singledispatch
def as_cupy_sparse_dask_array(a, format="csr"):
    """Convert `a` to a dask array of sparse blocks in the given `format`.

    The block container class is looked up in `format_to_memory_class[format]`
    and handed to `_as_sparse_dask`. The result is then rechunked, keeping the
    chunks of the first axis and passing ``-1`` for the second axis.
    """
    memory_class = format_to_memory_class[format]
    sparse_dask = _as_sparse_dask(a, typ=memory_class)
    row_chunks = sparse_dask.chunks[0]
    return sparse_dask.rechunk((row_chunks, -1))
957966
958967
959968@as_cupy_sparse_dask_array .register (CupyArray )
@@ -1003,7 +1012,7 @@ def as_cupy(val, typ=None):
10031012 if issubclass (typ , CupyArray ):
10041013 import cupy as cp
10051014
1006- if isinstance (val , CSMatrix ):
1015+ if isinstance (val , CSMatrix | CSArray ):
10071016 val = val .toarray ()
10081017 return cp .array (val )
10091018 elif issubclass (typ , CupyCSRMatrix ):
@@ -1059,7 +1068,14 @@ def shares_memory_sparse(x, y):
10591068
# pytest parameter entries, one per helper that converts its input to a dask array.
DASK_MATRIX_PARAMS = [
    pytest.param(as_dense_dask_array, id="dense_dask_array"),
    pytest.param(as_sparse_dask_matrix, id="sparse_dask_matrix"),
    # The sparse-array variant is skipped whenever DASK_CAN_SPARRAY is false.
    pytest.param(
        as_sparse_dask_array,
        marks=pytest.mark.skipif(
            not DASK_CAN_SPARRAY, reason="Dask does not support sparrays"
        ),
        id="sparse_dask_array",
    ),
]
10641080
10651081CUPY_MATRIX_PARAMS = [
0 commit comments