@@ -62,24 +62,18 @@ def instance_class(
6262 ) -> CSBase :
6363 return cls .cls ((data , indices , indptr ), shape , copy = False )
6464
65- def __init__ (
66- self ,
67- ndim : int ,
68- * ,
69- dtype_data : nbtypes .Type ,
70- dtype_indices : nbtypes .Type ,
71- dtype_indptr : nbtypes .Type ,
72- ) -> None :
73- self .dtype = nbtypes .DType (dtype_data )
74- self .data = nbtypes .Array (dtype_data , 1 , "A" )
75- self .indices = nbtypes .Array (dtype_indices , 1 , "A" )
76- self .indptr = nbtypes .Array (dtype_indptr , 1 , "A" )
65+ def __init__ (self , ndim : int , * , dtype : nbtypes .Type , dtype_ind : nbtypes .Type ) -> None :
66+ self .dtype = nbtypes .DType (dtype )
67+ self .dtype_ind = nbtypes .DType (dtype_ind )
68+ self .data = nbtypes .Array (dtype , 1 , "A" )
69+ self .indices = nbtypes .Array (dtype_ind , 1 , "A" )
70+ self .indptr = nbtypes .Array (dtype_ind , 1 , "A" )
7771 self .shape = nbtypes .UniTuple (nbtypes .intp , ndim )
7872 super ().__init__ (self .name )
7973
8074 @property
8175 def key (self ) -> tuple [str | nbtypes .Type , ...]:
82- return (self .name , self .dtype , self .indices . dtype , self . indptr . dtype )
76+ return (self .name , self .dtype , self .dtype_ind )
8377
8478
8579# make data model attributes available in numba functions
@@ -88,13 +82,15 @@ def key(self) -> tuple[str | nbtypes.Type, ...]:
8882
8983
9084def make_typeof_fn (typ : type [CSType ]) -> Callable [[CSBase , _TypeofContext ], CSType ]:
85+ """Create a `typeof` function that maps a scipy matrix/array type to a numba `Type`."""
86+
9187 def typeof (val : CSBase , c : _TypeofContext ) -> CSType :
88+ if val .indptr .dtype != val .indices .dtype :
89+ msg = "indptr and indices must have the same dtype"
90+ raise TypeError (msg )
9291 data = cast ("nbtypes.Array" , typeof_impl (val .data , c ))
93- indices = cast ("nbtypes.Array" , typeof_impl (val .indices , c ))
9492 indptr = cast ("nbtypes.Array" , typeof_impl (val .indptr , c ))
95- return typ (
96- val .ndim , dtype_data = data .dtype , dtype_indices = indices .dtype , dtype_indptr = indptr .dtype
97- )
93+ return typ (val .ndim , dtype = data .dtype , dtype_ind = indptr .dtype )
9894
9995 return typeof
10096
@@ -106,6 +102,11 @@ def typeof(val: CSBase, c: _TypeofContext) -> CSType:
106102
107103
108104class CSModel (_Base ):
105+ """Numba data model for compressed sparse matrices.
106+
107+ This is the class that is used by numba to lower the array types.
108+ """
109+
109110 def __init__ (self , dmm : DataModelManager , fe_type : CSType ) -> None :
110111 members = [
111112 ("data" , fe_type .data ),
@@ -116,6 +117,7 @@ def __init__(self, dmm: DataModelManager, fe_type: CSType) -> None:
116117 super ().__init__ (dmm , fe_type , members )
117118
118119
120+ # create all the actual types and data models
119121CLASSES : Sequence [type [CSBase ]] = [
120122 sparse .csr_matrix ,
121123 sparse .csc_matrix ,
0 commit comments