Skip to content

Commit ebc890f

Browse files
authored
Add a scipy.sparse numba extension (#73)
1 parent 104bf1c commit ebc890f

30 files changed

+706
-72
lines changed

Diff for: docs/conf.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@
7373
("py:class", "ToDType"),
7474
("py:class", "testing.fast_array_utils._array_type.Arr"),
7575
("py:class", "testing.fast_array_utils._array_type.Inner"),
76-
("py:class", "_DTypeLikeFloat32"),
77-
("py:class", "_DTypeLikeFloat64"),
76+
("py:class", "_DTypeLikeNum"),
7877
]
7978

8079
# Options for HTML output

Diff for: pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ optional-dependencies.sparse = [ "scipy>=1.8" ]
3434
optional-dependencies.test = [
3535
"anndata",
3636
"fast-array-utils[accel,test-min]",
37+
"numcodecs<0.16", # zarr 2 needs this
3738
"zarr<3", # anndata needs this
3839
]
3940
optional-dependencies.test-min = [
@@ -138,6 +139,8 @@ doctest_subpackage_requires = [
138139
"src/fast_array_utils/conv/scipy/* = scipy",
139140
"src/fast_array_utils/conv/scipy/_to_dense.py = numba",
140141
"src/fast_array_utils/stats/* = numba",
142+
"src/fast_array_utils/_plugins/dask.py = dask",
143+
"src/fast_array_utils/_plugins/numba_sparse.py = numba;scipy",
141144
]
142145
filterwarnings = [
143146
"error",

Diff for: src/fast_array_utils/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
from __future__ import annotations
1717

18-
from . import _patches, conv, stats, types
18+
from . import _plugins, conv, stats, types
1919

2020

2121
__all__ = ["conv", "stats", "types"]
2222

23-
_patches.patch_dask()
23+
_plugins.patch_dask()
24+
_plugins.register_numba_sparse()

Diff for: src/fast_array_utils/_plugins/__init__.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
5+
__all__ = ["patch_dask", "register_numba_sparse"]
6+
7+
8+
def patch_dask() -> None:
    r"""Patch Dask arrays so that they support `scipy.sparse.sparray`\ s.

    This is a best-effort operation: if the `.dask` plugin module cannot be
    imported (it raises `ImportError`), nothing is patched.
    """
    try:
        from .dask import patch as apply_patch
    except ImportError:
        # Plugin module is not importable, so there is nothing to patch.
        return
    # Called outside the ``try`` so an ImportError raised by the patch itself is not swallowed.
    apply_patch()
16+
17+
18+
def register_numba_sparse() -> None:
    r"""Register `scipy.sparse.sp{matrix,array}`\ s with Numba.

    Once registered, numba functions can operate on these types directly,
    which makes such functions cleaner to write.

    This is a best-effort operation: if the `.numba_sparse` plugin module
    cannot be imported (it raises `ImportError`), nothing is registered.
    """
    try:
        from .numba_sparse import register as do_register
    except ImportError:
        # Plugin module is not importable, so there is nothing to register.
        return
    # Called outside the ``try`` so an ImportError raised during registration is not swallowed.
    do_register()

Diff for: src/fast_array_utils/_patches.py renamed to src/fast_array_utils/_plugins/dask.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,18 @@
33

44
import numpy as np
55

6+
# Other lookup candidates: tensordot_lookup and take_lookup
7+
from dask.array.dispatch import concatenate_lookup
8+
from scipy.sparse import sparray, spmatrix
9+
610

711
# TODO(flying-sheep): upstream
812
# https://github.com/dask/dask/issues/11749
9-
def patch_dask() -> None: # pragma: no cover
13+
def patch() -> None: # pragma: no cover
1014
"""Patch dask to support sparse arrays.
1115
1216
See <https://github.com/dask/dask/blob/4d71629d1f22ced0dd780919f22e70a642ec6753/dask/array/backends.py#L212-L232>
1317
"""
14-
try:
15-
# Other lookup candidates: tensordot_lookup and take_lookup
16-
from dask.array.dispatch import concatenate_lookup
17-
from scipy.sparse import sparray, spmatrix
18-
except ImportError:
19-
return # No need to patch if dask or scipy is not installed
20-
2118
# Avoid patch if already patched or upstream support has been added
2219
if concatenate_lookup.dispatch(sparray) is not np.concatenate:
2320
return

Diff for: src/fast_array_utils/_plugins/numba_sparse.py

+270
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
"""Numba support for sparse arrays and matrices."""
3+
4+
# taken from https://github.com/numba/numba-scipy/blob/release0.4/numba_scipy/sparse.py
5+
# See https://numba.pydata.org/numba-doc/dev/extending/
6+
from __future__ import annotations
7+
8+
from typing import TYPE_CHECKING, cast
9+
10+
import numba.core.types as nbtypes
11+
import numpy as np
12+
from numba.core import cgutils
13+
from numba.core.imputils import impl_ret_borrowed
14+
from numba.extending import (
15+
NativeValue,
16+
box,
17+
intrinsic,
18+
make_attribute_wrapper,
19+
models,
20+
overload,
21+
overload_attribute,
22+
overload_method,
23+
register_model,
24+
typeof_impl,
25+
unbox,
26+
)
27+
from scipy import sparse
28+
29+
30+
if TYPE_CHECKING:
31+
from collections.abc import Callable, Mapping, Sequence
32+
from typing import Any, ClassVar, Literal
33+
34+
from llvmlite.ir import IRBuilder, Value
35+
from numba.core.base import BaseContext
36+
from numba.core.datamodel.manager import DataModelManager
37+
from numba.core.extending import BoxContext, TypingContext, UnboxContext
38+
from numba.core.typing.templates import Signature
39+
from numba.core.typing.typeof import _TypeofContext
40+
from numpy.typing import NDArray
41+
42+
from fast_array_utils.types import CSBase
43+
44+
45+
class CSType(nbtypes.Type):
    """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.

    This is an abstract base class for the actually used, registered types in `TYPES` below.
    It collects information about the type (e.g. field dtypes) for later use in the data model.

    Concrete subclasses set the two class variables:

    ``name``
        Name of the scipy class, also used as the numba type name.
    ``cls``
        The scipy class that instances of this type are boxed into.
    """

    name: ClassVar[Literal["csr_matrix", "csc_matrix", "csr_array", "csc_array"]]
    cls: ClassVar[type[CSBase]]

    @classmethod
    def instance_class(
        cls,
        data: NDArray[np.number[Any]],
        indices: NDArray[np.integer[Any]],
        indptr: NDArray[np.integer[Any]],
        shape: tuple[int, int],  # actually tuple[int, ...] for sparray subclasses
    ) -> CSBase:
        """Build the scipy object for this type from its members, without copying them."""
        return cls.cls((data, indices, indptr), shape, copy=False)

    def __init__(self, ndim: int, *, dtype: nbtypes.Type, dtype_ind: nbtypes.Type) -> None:
        """Create the type for a sparse container.

        Parameters
        ----------
        ndim
            Number of dimensions, i.e. the length of the ``shape`` tuple.
        dtype
            Numba type of the elements of ``data``.
        dtype_ind
            Numba type of the elements of both ``indices`` and ``indptr``.
        """
        self.dtype = nbtypes.DType(dtype)
        self.dtype_ind = nbtypes.DType(dtype_ind)
        self.data = nbtypes.Array(dtype, 1, "A")
        self.indices = nbtypes.Array(dtype_ind, 1, "A")
        self.indptr = nbtypes.Array(dtype_ind, 1, "A")
        self.shape = nbtypes.UniTuple(nbtypes.intp, ndim)
        super().__init__(self.name)

    @property
    def key(self) -> tuple[str | nbtypes.Type, ...]:
        """Identity of this type.

        Numba uses the key for equality, hashing and interning of types,
        so it has to contain everything that distinguishes two instances.
        The shape type is part of it because its length depends on ``ndim``:
        without it, containers of different dimensionality but identical dtypes
        would be treated as the same type.
        """
        return (self.name, self.dtype, self.dtype_ind, self.shape)
77+
78+
79+
# Expose the members of the data model as attributes inside numba functions,
# each under the same name it has in the struct.
for attr in ("data", "indices", "indptr", "shape"):
    make_attribute_wrapper(CSType, attr, attr)
82+
83+
84+
def make_typeof_fn(typ: type[CSType]) -> Callable[[CSBase, _TypeofContext], CSType]:
    """Create a `typeof` function that maps a scipy matrix/array type to a numba `Type`.

    The returned function derives the element types from the value's ``data``
    and ``indptr`` arrays and instantiates `typ` with them and the value's ``ndim``.
    It raises `TypeError` if ``indptr`` and ``indices`` differ in dtype,
    since the type stores a single index dtype for both.
    """

    def typeof(val: CSBase, c: _TypeofContext) -> CSType:
        same_index_dtype = val.indptr.dtype == val.indices.dtype
        if not same_index_dtype:  # pragma: no cover
            msg = "indptr and indices must have the same dtype"
            raise TypeError(msg)
        data_type = cast("nbtypes.Array", typeof_impl(val.data, c))
        index_type = cast("nbtypes.Array", typeof_impl(val.indptr, c))
        return typ(val.ndim, dtype=data_type.dtype, dtype_ind=index_type.dtype)

    return typeof
96+
97+
98+
# Only subscript ``StructModel`` for the type checker; at runtime use the plain class.
if TYPE_CHECKING:
    _CSModelBase = models.StructModel[CSType]
else:
    _CSModelBase = models.StructModel


class CSModel(_CSModelBase):
    """Numba data model for compressed sparse matrices.

    This is the class that is used by numba to lower the array types.
    It is a struct with the members ``data``, ``indices``, ``indptr`` and ``shape``,
    whose types are read from the attributes of the same name on the `CSType`.
    """

    def __init__(self, dmm: DataModelManager, fe_type: CSType) -> None:
        member_names = ("data", "indices", "indptr", "shape")
        members = [(member, getattr(fe_type, member)) for member in member_names]
        super().__init__(dmm, fe_type, members)
118+
119+
120+
# Create all the actual types and data models.

# The scipy classes that get numba support.
CLASSES: Sequence[type[CSBase]] = [
    sparse.csr_matrix,
    sparse.csc_matrix,
    sparse.csr_array,
    sparse.csc_array,
]
# One concrete `CSType` subclass per scipy class, named `<scipy class name>Type`.
TYPES: Sequence[type[CSType]] = [
    type(
        f"{cls.__name__}Type",
        (CSType,),
        {"cls": cls, "name": cls.__name__},
    )
    for cls in CLASSES
]
# scipy class -> function computing the numba type of one of its instances.
TYPEOF_FUNCS: Mapping[type[CSBase], Callable[[CSBase, _TypeofContext], CSType]] = {
    typ.cls: make_typeof_fn(typ)
    for typ in TYPES
}
# numba type -> its data model, named `<scipy class name>Model`.
MODELS: Mapping[type[CSType], type[CSModel]] = {
    typ: type(f"{typ.cls.__name__}Model", (CSModel,), {})
    for typ in TYPES
}
136+
137+
138+
def unbox_matrix(typ: CSType, obj: Value, c: UnboxContext) -> NativeValue:
    """Convert a Python cs{rc}_{matrix,array} to a Numba value.

    Parameters
    ----------
    typ
        The numba type of the object; provides the types of the members.
    obj
        The Python object to convert.
    c
        The unboxing context (builder, Python API, ...).

    Returns
    -------
    The native struct holding the unboxed ``data``, ``indices``, ``indptr`` and ``shape``.
    Its error flag is set if a Python exception is pending after the conversion.
    """
    struct_proxy_cls = cgutils.create_struct_proxy(typ)
    struct_ptr = struct_proxy_cls(c.context, c.builder)

    # fetch the members from the Python object (new references)
    data = c.pyapi.object_getattr_string(obj, "data")
    indices = c.pyapi.object_getattr_string(obj, "indices")
    indptr = c.pyapi.object_getattr_string(obj, "indptr")
    shape = c.pyapi.object_getattr_string(obj, "shape")

    # convert each member to its native representation
    struct_ptr.data = c.unbox(typ.data, data).value
    struct_ptr.indices = c.unbox(typ.indices, indices).value
    struct_ptr.indptr = c.unbox(typ.indptr, indptr).value
    struct_ptr.shape = c.unbox(typ.shape, shape).value

    # release the references obtained by the attribute lookups
    c.pyapi.decref(data)
    c.pyapi.decref(indices)
    c.pyapi.decref(indptr)
    c.pyapi.decref(shape)

    # Report a failed attribute lookup or member conversion to the caller
    # instead of unconditionally claiming success.
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())

    return NativeValue(struct_ptr._getvalue(), is_error=is_error)  # noqa: SLF001
162+
163+
164+
def box_matrix(typ: CSType, val: NativeValue, c: BoxContext) -> Value:
    """Convert numba value into a Python cs{rc}_{matrix,array}.

    The members of the native struct are boxed individually and handed to
    ``typ.instance_class``, which is made available to the generated code
    by serializing it.
    """
    pyapi = c.pyapi
    proxy = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)

    # box the members, in the order `instance_class` expects them
    ctor_args = (
        c.box(typ.data, proxy.data),
        c.box(typ.indices, proxy.indices),
        c.box(typ.indptr, proxy.indptr),
        c.box(typ.shape, proxy.shape),
    )

    # NOTE(review): every boxed member is incref'd here and decref'd once after the call,
    # and the unserialized constructor is never decref'd — confirm that this
    # reference counting is intended and does not leak.
    for arg in ctor_args:
        pyapi.incref(arg)

    constructor = pyapi.unserialize(pyapi.serialize_object(typ.instance_class))
    boxed = pyapi.call_function_objargs(constructor, ctor_args)

    for arg in ctor_args:
        pyapi.decref(arg)

    return boxed
188+
189+
190+
# See https://numba.readthedocs.io/en/stable/extending/overloading-guide.html
191+
@overload(np.shape)
def overload_sparse_shape(x: CSType) -> None | Callable[[CSType], nbtypes.UniTuple]:
    """Implement `numpy.shape` for sparse types; other types are left to other overloads."""
    if isinstance(x, CSType):
        # nopython code:
        def shape(x: CSType) -> nbtypes.UniTuple:  # pragma: no cover
            return x.shape

        return shape

    # not ours: decline
    return None  # pragma: no cover
201+
202+
203+
@overload_attribute(CSType, "ndim")
def overload_sparse_ndim(inst: CSType) -> None | Callable[[CSType], int]:
    """Provide the ``ndim`` attribute of sparse types as the length of their shape."""
    if isinstance(inst, CSType):
        # nopython code:
        def ndim(inst: CSType) -> int:  # pragma: no cover
            return len(inst.shape)

        return ndim

    # not ours: decline
    return None  # pragma: no cover
213+
214+
215+
@intrinsic
def _sparse_copy(
    typingctx: TypingContext,  # noqa: ARG001
    inst: CSType,
    data: nbtypes.Array,  # noqa: ARG001
    indices: nbtypes.Array,  # noqa: ARG001
    indptr: nbtypes.Array,  # noqa: ARG001
    shape: nbtypes.UniTuple,  # noqa: ARG001
) -> tuple[Signature, Callable[..., NativeValue]]:
    """Assemble a new sparse value of the type of `inst` from the given members.

    `inst` only determines the signature; its value is not used by the generated code.
    """

    def _construct(
        context: BaseContext,
        builder: IRBuilder,
        sig: Signature,
        args: tuple[Value, Value, Value, Value, Value],
    ) -> NativeValue:
        result_type = sig.return_type
        struct = cgutils.create_struct_proxy(result_type)(context, builder)
        # the first argument is the source instance, which is not needed here
        struct.data, struct.indices, struct.indptr, struct.shape = args[1:]
        native = struct._getvalue()  # noqa: SLF001
        return impl_ret_borrowed(context, builder, result_type, native)

    # returns the same type as `inst`, built from members typed like those of `inst`
    sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)

    return sig, _construct
247+
248+
249+
@overload_method(CSType, "copy")
def overload_sparse_copy(inst: CSType) -> None | Callable[[CSType], CSType]:
    """Provide the ``copy`` method of sparse types.

    The copy gets its own copies of ``data``, ``indices`` and ``indptr``
    and the same ``shape``.
    """
    if not isinstance(inst, CSType):  # pragma: no cover
        return None

    # nopython code:
    def copy(inst: CSType) -> CSType:  # pragma: no cover
        data = inst.data.copy()
        indices = inst.indices.copy()
        indptr = inst.indptr.copy()
        return _sparse_copy(inst, data, indices, indptr, inst.shape)  # type: ignore[return-value]

    return copy
261+
262+
263+
def register() -> None:
    """Register the numba types, data models, and mappings between them and the Python types.

    For every scipy class in `TYPEOF_FUNCS`, its `typeof` function is registered;
    for every numba type in `MODELS`, its data model and the shared
    boxing/unboxing functions are registered.
    """
    for scipy_cls, typeof_fn in TYPEOF_FUNCS.items():
        typeof_impl.register(scipy_cls, typeof_fn)

    for nb_type, model_cls in MODELS.items():
        register_model(nb_type)(model_cls)
        unbox(nb_type)(unbox_matrix)
        box(nb_type)(box_matrix)

Diff for: src/fast_array_utils/conv/scipy/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def to_dense(x: types.spmatrix | types.sparray, order: Literal["C", "F"] = "C")
4949

5050
out = np.zeros(x.shape, dtype=x.dtype, order=order)
5151
if x.format == "csr":
52-
_to_dense_csr_numba(x.indptr, x.indices, x.data, out)
52+
_to_dense_csr_numba(x, out)
5353
elif x.format == "csc":
54-
_to_dense_csc_numba(x.indptr, x.indices, x.data, out)
54+
_to_dense_csc_numba(x, out)
5555
else: # pragma: no cover
5656
out = x.toarray(order=order)
5757
return out

0 commit comments

Comments
 (0)