|
| 1 | +# SPDX-License-Identifier: MPL-2.0 |
| 2 | +"""Numba support for sparse arrays and matrices.""" |
| 3 | + |
| 4 | +# taken from https://github.com/numba/numba-scipy/blob/release0.4/numba_scipy/sparse.py |
| 5 | +# See https://numba.pydata.org/numba-doc/dev/extending/ |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +from typing import TYPE_CHECKING, cast |
| 9 | + |
| 10 | +import numba.core.types as nbtypes |
| 11 | +import numpy as np |
| 12 | +from numba.core import cgutils |
| 13 | +from numba.core.imputils import impl_ret_borrowed |
| 14 | +from numba.extending import ( |
| 15 | + NativeValue, |
| 16 | + box, |
| 17 | + intrinsic, |
| 18 | + make_attribute_wrapper, |
| 19 | + models, |
| 20 | + overload, |
| 21 | + overload_attribute, |
| 22 | + overload_method, |
| 23 | + register_model, |
| 24 | + typeof_impl, |
| 25 | + unbox, |
| 26 | +) |
| 27 | +from scipy import sparse |
| 28 | + |
| 29 | + |
| 30 | +if TYPE_CHECKING: |
| 31 | + from collections.abc import Callable, Mapping, Sequence |
| 32 | + from typing import Any, ClassVar, Literal |
| 33 | + |
| 34 | + from llvmlite.ir import IRBuilder, Value |
| 35 | + from numba.core.base import BaseContext |
| 36 | + from numba.core.datamodel.manager import DataModelManager |
| 37 | + from numba.core.extending import BoxContext, TypingContext, UnboxContext |
| 38 | + from numba.core.typing.templates import Signature |
| 39 | + from numba.core.typing.typeof import _TypeofContext |
| 40 | + from numpy.typing import NDArray |
| 41 | + |
| 42 | + from fast_array_utils.types import CSBase |
| 43 | + |
| 44 | + |
class CSType(nbtypes.Type):
    """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.

    This is an abstract base class for the actually used, registered types in `TYPES` below.
    It collects information about the type (e.g. field dtypes) for later use in the data model.

    Each instance carries the numba types of its members:

    - ``dtype`` / ``dtype_ind``: ``DType`` wrappers around the value and index dtypes.
    - ``data`` / ``indices`` / ``indptr``: 1-dimensional ``Array`` types with layout ``"A"``.
    - ``shape``: a ``UniTuple`` of ``ndim`` ``intp`` values.
    """

    # name of the wrapped scipy class, set by the concrete subclasses
    name: ClassVar[Literal["csr_matrix", "csc_matrix", "csr_array", "csc_array"]]
    # the wrapped scipy class itself, set by the concrete subclasses
    cls: ClassVar[type[CSBase]]

    @classmethod
    def instance_class(
        cls,
        data: NDArray[np.number[Any]],
        indices: NDArray[np.integer[Any]],
        indptr: NDArray[np.integer[Any]],
        shape: tuple[int, int],  # actually tuple[int, ...] for sparray subclasses
    ) -> CSBase:
        """Create an instance of the wrapped scipy class from its components (``copy=False``)."""
        return cls.cls((data, indices, indptr), shape, copy=False)

    def __init__(self, ndim: int, *, dtype: nbtypes.Type, dtype_ind: nbtypes.Type) -> None:
        """Derive the member types from the number of dimensions and the two dtypes.

        Parameters
        ----------
        ndim
            Number of entries in the ``shape`` tuple.
        dtype
            Numba type of the elements of ``data``.
        dtype_ind
            Numba type of the elements of ``indices`` and ``indptr``.
        """
        self.dtype = nbtypes.DType(dtype)
        self.dtype_ind = nbtypes.DType(dtype_ind)
        self.data = nbtypes.Array(dtype, 1, "A")
        self.indices = nbtypes.Array(dtype_ind, 1, "A")
        self.indptr = nbtypes.Array(dtype_ind, 1, "A")
        self.shape = nbtypes.UniTuple(nbtypes.intp, ndim)
        super().__init__(self.name)

    @property
    def key(self) -> tuple[str | nbtypes.Type, ...]:
        """Return the identity of this type, which numba uses for equality and hashing.

        ``shape`` is part of the key because its length depends on ``ndim``:
        without it, two types that differ only in dimensionality would be
        treated as the same type although their data models differ.
        """
        return (self.name, self.dtype, self.dtype_ind, self.shape)
| 77 | + |
| 78 | + |
# expose the struct members of the data model as attributes inside numba-compiled functions
for attr in ("data", "indices", "indptr", "shape"):
    make_attribute_wrapper(CSType, struct_attr=attr, python_attr=attr)
| 82 | + |
| 83 | + |
def make_typeof_fn(typ: type[CSType]) -> Callable[[CSBase, _TypeofContext], CSType]:
    """Build the `typeof` hook that maps a scipy matrix/array to an instance of the numba type `typ`."""

    def typeof(val: CSBase, c: _TypeofContext) -> CSType:
        """Infer the numba type of `val` from its `data` and `indptr` arrays and its `ndim`."""
        if val.indptr.dtype == val.indices.dtype:
            value_type = cast("nbtypes.Array", typeof_impl(val.data, c))
            index_type = cast("nbtypes.Array", typeof_impl(val.indptr, c))
            return typ(val.ndim, dtype=value_type.dtype, dtype_ind=index_type.dtype)

        # a single index dtype is stored on the type, so both index arrays have to agree
        msg = "indptr and indices must have the same dtype"  # pragma: no cover
        raise TypeError(msg)  # pragma: no cover

    return typeof
| 96 | + |
| 97 | + |
# `StructModel` is only subscripted for type checkers; at runtime the plain class is the base
if not TYPE_CHECKING:
    _CSModelBase = models.StructModel
else:
    _CSModelBase = models.StructModel[CSType]
| 102 | + |
| 103 | + |
class CSModel(_CSModelBase):
    """Numba data model for compressed sparse matrices.

    Numba uses this class to lower the sparse types to a struct
    with the members `data`, `indices`, `indptr` and `shape`.
    """

    def __init__(self, dmm: DataModelManager, fe_type: CSType) -> None:
        """Declare the struct members, taking each member's type from `fe_type`."""
        member_names = ("data", "indices", "indptr", "shape")
        members = [(member, getattr(fe_type, member)) for member in member_names]
        super().__init__(dmm, fe_type, members)
| 118 | + |
| 119 | + |
# the supported scipy classes, plus one numba type, `typeof` function and data model for each
CLASSES: Sequence[type[CSBase]] = [
    sparse.csr_matrix,
    sparse.csc_matrix,
    sparse.csr_array,
    sparse.csc_array,
]
TYPES: Sequence[type[CSType]] = [
    type(f"{sp_cls.__name__}Type", (CSType,), {"cls": sp_cls, "name": sp_cls.__name__})
    for sp_cls in CLASSES
]
TYPEOF_FUNCS: Mapping[type[CSBase], Callable[[CSBase, _TypeofContext], CSType]] = {
    cs_type.cls: make_typeof_fn(cs_type) for cs_type in TYPES
}
MODELS: Mapping[type[CSType], type[CSModel]] = {
    cs_type: type(f"{cs_type.cls.__name__}Model", (CSModel,), {}) for cs_type in TYPES
}
| 136 | + |
| 137 | + |
def unbox_matrix(typ: CSType, obj: Value, c: UnboxContext) -> NativeValue:
    """Convert a Python cs{rc}_{matrix,array} to a Numba value.

    Each of the attributes ``data``, ``indices``, ``indptr`` and ``shape`` is read
    from ``obj``, unboxed with the member type stored on ``typ``, and written into
    a new native struct.

    Parameters
    ----------
    typ
        The numba type of ``obj``; supplies the member types.
    obj
        The Python object to convert.
    c
        The unboxing context (builder, Python API, nested ``unbox``).

    Returns
    -------
    The native struct, with ``is_error`` set if unboxing any member failed.
    """
    struct_proxy_cls = cgutils.create_struct_proxy(typ)
    struct_ptr = struct_proxy_cls(c.context, c.builder)

    is_error = cgutils.false_bit
    for member in ("data", "indices", "indptr", "shape"):
        member_obj = c.pyapi.object_getattr_string(obj, member)
        native = c.unbox(getattr(typ, member), member_obj)
        setattr(struct_ptr, member, native.value)
        c.pyapi.decref(member_obj)
        # propagate a failed member conversion instead of always reporting success
        is_error = c.builder.or_(is_error, native.is_error)

    return NativeValue(struct_ptr._getvalue(), is_error=is_error)  # noqa: SLF001
| 162 | + |
| 163 | + |
def box_matrix(typ: CSType, val: NativeValue, c: BoxContext) -> Value:
    """Convert numba value into a Python cs{rc}_{matrix,array}.

    The four struct members are boxed individually and passed to
    ``typ.instance_class``, whose result is returned.
    """
    proxy = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)

    # box the members in the argument order of `instance_class`
    member_objs = tuple(
        c.box(getattr(typ, member), getattr(proxy, member))
        for member in ("data", "indices", "indptr", "shape")
    )

    # NOTE(review): each incref below is undone by the decref after the call, so the
    # references returned by `c.box` (and `constructor`) are never released here —
    # confirm whether this is intended or a leak.
    for member_obj in member_objs:
        c.pyapi.incref(member_obj)

    constructor = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class))
    result = c.pyapi.call_function_objargs(constructor, member_objs)

    for member_obj in member_objs:
        c.pyapi.decref(member_obj)

    return result
| 188 | + |
| 189 | + |
# See https://numba.readthedocs.io/en/stable/extending/overloading-guide.html
@overload(np.shape)
def overload_sparse_shape(x: CSType) -> None | Callable[[CSType], nbtypes.UniTuple]:
    """Implement `np.shape` for sparse types; return `None` for any other argument type."""
    if isinstance(x, CSType):
        # nopython code:
        def impl(x: CSType) -> nbtypes.UniTuple:  # pragma: no cover
            return x.shape

        return impl

    return None  # pragma: no cover
| 201 | + |
| 202 | + |
@overload_attribute(CSType, "ndim")
def overload_sparse_ndim(inst: CSType) -> None | Callable[[CSType], int]:
    """Implement the `ndim` attribute of sparse types as the length of their `shape`."""
    if isinstance(inst, CSType):
        # nopython code:
        def impl(inst: CSType) -> int:  # pragma: no cover
            return len(inst.shape)

        return impl

    return None  # pragma: no cover
| 213 | + |
| 214 | + |
@intrinsic
def _sparse_copy(
    typingctx: TypingContext,  # noqa: ARG001
    inst: CSType,
    data: nbtypes.Array,  # noqa: ARG001
    indices: nbtypes.Array,  # noqa: ARG001
    indptr: nbtypes.Array,  # noqa: ARG001
    shape: nbtypes.UniTuple,  # noqa: ARG001
) -> tuple[Signature, Callable[..., NativeValue]]:
    """Assemble a new sparse struct of type `inst` from the given members.

    The first runtime argument only determines the return type; its value is ignored.
    """
    signature = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)

    def codegen(
        context: BaseContext,
        builder: IRBuilder,
        sig: Signature,
        args: tuple[Value, Value, Value, Value, Value],
    ) -> NativeValue:
        # drop the template instance, keep the four members
        _, *members = args
        struct = cgutils.create_struct_proxy(sig.return_type)(context, builder)
        struct.data, struct.indices, struct.indptr, struct.shape = members
        result = struct._getvalue()  # noqa: SLF001
        return impl_ret_borrowed(context, builder, sig.return_type, result)

    return signature, codegen
| 247 | + |
| 248 | + |
@overload_method(CSType, "copy")
def overload_sparse_copy(inst: CSType) -> None | Callable[[CSType], CSType]:
    """Implement `.copy()` for sparse types by copying `data`, `indices` and `indptr`."""
    if isinstance(inst, CSType):
        # nopython code:
        def impl(inst: CSType) -> CSType:  # pragma: no cover
            data = inst.data.copy()
            indices = inst.indices.copy()
            indptr = inst.indptr.copy()
            return _sparse_copy(inst, data, indices, indptr, inst.shape)  # type: ignore[return-value]

        return impl

    return None  # pragma: no cover
| 261 | + |
| 262 | + |
def register() -> None:
    """Hook the sparse types into numba: `typeof` functions, data models, and (un)boxing."""
    # Python class -> numba type
    for sp_cls, typeof_fn in TYPEOF_FUNCS.items():
        typeof_impl.register(sp_cls, typeof_fn)

    # numba type -> data model and conversions from/to Python objects
    for cs_type, model_cls in MODELS.items():
        register_model(cs_type)(model_cls)
        unbox(cs_type)(unbox_matrix)
        box(cs_type)(box_matrix)
0 commit comments