Skip to content

Commit 32e6e0b

Browse files
authored
API: Establish pydata_backend (#646)
1 parent 92e8078 commit 32e6e0b

40 files changed

+437
-353
lines changed

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ tests = [
3737
]
3838
tox = ["sparse[tests]", "tox"]
3939
all = ["sparse[docs,tox]", "matrepr"]
40+
finch = ["finch-tensor"]
4041

4142
[project.urls]
4243
Documentation = "https://sparse.pydata.org/"
@@ -46,7 +47,7 @@ Repository = "https://github.com/pydata/sparse.git"
4647
Discussions = "https://github.com/pydata/sparse/discussions"
4748

4849
[project.entry-points.numba_extensions]
49-
init = "sparse._numba_extension:_init_extension"
50+
init = "sparse.pydata_backend._numba_extension:_init_extension"
5051

5152
[tool.setuptools.packages.find]
5253
where = ["."]

sparse/__init__.py

+54-319
Original file line numberDiff line numberDiff line change
@@ -1,323 +1,58 @@
1-
from numpy import (
2-
add,
3-
bitwise_and,
4-
bitwise_not,
5-
bitwise_or,
6-
bitwise_xor,
7-
can_cast,
8-
ceil,
9-
complex64,
10-
complex128,
11-
cos,
12-
cosh,
13-
divide,
14-
e,
15-
exp,
16-
expm1,
17-
finfo,
18-
float16,
19-
float32,
20-
float64,
21-
floor,
22-
floor_divide,
23-
greater,
24-
greater_equal,
25-
iinfo,
26-
imag,
27-
inf,
28-
int8,
29-
int16,
30-
int32,
31-
int64,
32-
less,
33-
less_equal,
34-
log,
35-
log1p,
36-
log2,
37-
log10,
38-
logaddexp,
39-
logical_and,
40-
logical_not,
41-
logical_or,
42-
logical_xor,
43-
multiply,
44-
nan,
45-
negative,
46-
newaxis,
47-
not_equal,
48-
pi,
49-
positive,
50-
real,
51-
remainder,
52-
sign,
53-
sin,
54-
sinh,
55-
sqrt,
56-
square,
57-
subtract,
58-
tan,
59-
tanh,
60-
trunc,
61-
uint8,
62-
uint16,
63-
uint32,
64-
uint64,
65-
)
66-
from numpy import arccos as acos
67-
from numpy import arccosh as acosh
68-
from numpy import arcsin as asin
69-
from numpy import arcsinh as asinh
70-
from numpy import arctan as atan
71-
from numpy import arctan2 as atan2
72-
from numpy import arctanh as atanh
73-
from numpy import bool_ as bool
74-
from numpy import invert as bitwise_invert
75-
from numpy import left_shift as bitwise_left_shift
76-
from numpy import power as pow
77-
from numpy import right_shift as bitwise_right_shift
1+
import os
2+
from contextvars import ContextVar
3+
from enum import Enum
784

79-
from ._common import (
80-
SparseArray,
81-
abs,
82-
all,
83-
any,
84-
asarray,
85-
asnumpy,
86-
astype,
87-
broadcast_arrays,
88-
broadcast_to,
89-
concat,
90-
concatenate,
91-
dot,
92-
einsum,
93-
empty,
94-
empty_like,
95-
equal,
96-
eye,
97-
full,
98-
full_like,
99-
isfinite,
100-
isinf,
101-
isnan,
102-
matmul,
103-
max,
104-
mean,
105-
min,
106-
moveaxis,
107-
nonzero,
108-
ones,
109-
ones_like,
110-
outer,
111-
pad,
112-
permute_dims,
113-
prod,
114-
reshape,
115-
round,
116-
squeeze,
117-
stack,
118-
std,
119-
sum,
120-
tensordot,
121-
var,
122-
vecdot,
123-
zeros,
124-
zeros_like,
125-
)
126-
from ._compressed import GCXS
127-
from ._coo import COO, as_coo
128-
from ._coo.common import (
129-
argmax,
130-
argmin,
131-
argwhere,
132-
asCOO,
133-
clip,
134-
diagonal,
135-
diagonalize,
136-
expand_dims,
137-
flip,
138-
isneginf,
139-
isposinf,
140-
kron,
141-
matrix_transpose,
142-
nanmax,
143-
nanmean,
144-
nanmin,
145-
nanprod,
146-
nanreduce,
147-
nansum,
148-
result_type,
149-
roll,
150-
sort,
151-
take,
152-
tril,
153-
triu,
154-
unique_counts,
155-
unique_values,
156-
where,
157-
)
158-
from ._dok import DOK
159-
from ._io import load_npz, save_npz
160-
from ._umath import elemwise
161-
from ._utils import random
1625
from ._version import __version__, __version_tuple__ # noqa: F401
1636

164-
__all__ = [
165-
"COO",
166-
"DOK",
167-
"GCXS",
168-
"SparseArray",
169-
"abs",
170-
"acos",
171-
"acosh",
172-
"add",
173-
"all",
174-
"any",
175-
"argmax",
176-
"argmin",
177-
"argwhere",
178-
"asCOO",
179-
"as_coo",
180-
"asarray",
181-
"asin",
182-
"asinh",
183-
"asnumpy",
184-
"astype",
185-
"atan",
186-
"atan2",
187-
"atanh",
188-
"bitwise_and",
189-
"bitwise_invert",
190-
"bitwise_left_shift",
191-
"bitwise_not",
192-
"bitwise_or",
193-
"bitwise_right_shift",
194-
"bitwise_xor",
195-
"bool",
196-
"broadcast_arrays",
197-
"broadcast_to",
198-
"can_cast",
199-
"ceil",
200-
"clip",
201-
"complex128",
202-
"complex64",
203-
"concat",
204-
"concatenate",
205-
"cos",
206-
"cosh",
207-
"diagonal",
208-
"diagonalize",
209-
"divide",
210-
"dot",
211-
"e",
212-
"einsum",
213-
"elemwise",
214-
"empty",
215-
"empty_like",
216-
"equal",
217-
"exp",
218-
"expand_dims",
219-
"expm1",
220-
"eye",
221-
"finfo",
222-
"flip",
223-
"float16",
224-
"float32",
225-
"float64",
226-
"floor",
227-
"floor_divide",
228-
"full",
229-
"full_like",
230-
"greater",
231-
"greater_equal",
232-
"iinfo",
233-
"imag",
234-
"inf",
235-
"int16",
236-
"int32",
237-
"int64",
238-
"int8",
239-
"isfinite",
240-
"isinf",
241-
"isnan",
242-
"isneginf",
243-
"isposinf",
244-
"kron",
245-
"less",
246-
"less_equal",
247-
"load_npz",
248-
"log",
249-
"log10",
250-
"log1p",
251-
"log2",
252-
"logaddexp",
253-
"logical_and",
254-
"logical_not",
255-
"logical_or",
256-
"logical_xor",
257-
"matmul",
258-
"matrix_transpose",
259-
"max",
260-
"mean",
261-
"min",
262-
"moveaxis",
263-
"multiply",
264-
"nan",
265-
"nanmax",
266-
"nanmean",
267-
"nanmin",
268-
"nanprod",
269-
"nanreduce",
270-
"nansum",
271-
"negative",
272-
"newaxis",
273-
"nonzero",
274-
"not_equal",
275-
"ones",
276-
"ones_like",
277-
"outer",
278-
"pad",
279-
"permute_dims",
280-
"pi",
281-
"positive",
282-
"pow",
283-
"prod",
284-
"random",
285-
"real",
286-
"remainder",
287-
"reshape",
288-
"result_type",
289-
"roll",
290-
"round",
291-
"save_npz",
292-
"sign",
293-
"sin",
294-
"sinh",
295-
"sort",
296-
"sqrt",
297-
"square",
298-
"squeeze",
299-
"stack",
300-
"std",
301-
"subtract",
302-
"sum",
303-
"take",
304-
"tan",
305-
"tanh",
306-
"tensordot",
307-
"tril",
308-
"triu",
309-
"trunc",
310-
"uint16",
311-
"uint32",
312-
"uint64",
313-
"uint8",
314-
"unique_counts",
315-
"unique_values",
316-
"var",
317-
"vecdot",
318-
"where",
319-
"zeros",
320-
"zeros_like",
321-
]
322-
3237
__array_api_version__ = "2022.12"
8+
9+
10+
class BackendType(Enum):
11+
PyData = "PyData"
12+
Finch = "Finch"
13+
14+
15+
_ENV_VAR_NAME = "SPARSE_BACKEND"
16+
17+
backend_var = ContextVar("backend", default=BackendType.PyData)
18+
19+
if _ENV_VAR_NAME in os.environ:
20+
backend_var.set(BackendType[os.environ[_ENV_VAR_NAME]])
21+
22+
23+
class Backend:
24+
def __init__(self, backend=BackendType.PyData):
25+
self.backend = backend
26+
self.token = None
27+
28+
def __enter__(self):
29+
token = backend_var.set(self.backend)
30+
self.token = token
31+
32+
def __exit__(self, exc_type, exc_value, traceback):
33+
backend_var.reset(self.token)
34+
self.token = None
35+
36+
@staticmethod
37+
def get_backend_module():
38+
backend = backend_var.get()
39+
if backend == BackendType.PyData:
40+
import sparse.pydata_backend as backend_module
41+
elif backend == BackendType.Finch:
42+
import sparse.finch_backend as backend_module
43+
else:
44+
raise ValueError(f"Invalid backend identifier: {backend}")
45+
return backend_module
46+
47+
48+
def __getattr__(attr):
49+
if attr == "pydata_backend":
50+
import sparse.pydata_backend as backend_module
51+
52+
return backend_module
53+
if attr == "finch_backend":
54+
import sparse.finch_backend as backend_module
55+
56+
return backend_module
57+
58+
return getattr(Backend.get_backend_module(), attr)

0 commit comments

Comments
 (0)