|
1 |
| -from numpy import ( |
2 |
| - add, |
3 |
| - bitwise_and, |
4 |
| - bitwise_not, |
5 |
| - bitwise_or, |
6 |
| - bitwise_xor, |
7 |
| - can_cast, |
8 |
| - ceil, |
9 |
| - complex64, |
10 |
| - complex128, |
11 |
| - cos, |
12 |
| - cosh, |
13 |
| - divide, |
14 |
| - e, |
15 |
| - exp, |
16 |
| - expm1, |
17 |
| - finfo, |
18 |
| - float16, |
19 |
| - float32, |
20 |
| - float64, |
21 |
| - floor, |
22 |
| - floor_divide, |
23 |
| - greater, |
24 |
| - greater_equal, |
25 |
| - iinfo, |
26 |
| - imag, |
27 |
| - inf, |
28 |
| - int8, |
29 |
| - int16, |
30 |
| - int32, |
31 |
| - int64, |
32 |
| - less, |
33 |
| - less_equal, |
34 |
| - log, |
35 |
| - log1p, |
36 |
| - log2, |
37 |
| - log10, |
38 |
| - logaddexp, |
39 |
| - logical_and, |
40 |
| - logical_not, |
41 |
| - logical_or, |
42 |
| - logical_xor, |
43 |
| - multiply, |
44 |
| - nan, |
45 |
| - negative, |
46 |
| - newaxis, |
47 |
| - not_equal, |
48 |
| - pi, |
49 |
| - positive, |
50 |
| - real, |
51 |
| - remainder, |
52 |
| - sign, |
53 |
| - sin, |
54 |
| - sinh, |
55 |
| - sqrt, |
56 |
| - square, |
57 |
| - subtract, |
58 |
| - tan, |
59 |
| - tanh, |
60 |
| - trunc, |
61 |
| - uint8, |
62 |
| - uint16, |
63 |
| - uint32, |
64 |
| - uint64, |
65 |
| -) |
66 |
| -from numpy import arccos as acos |
67 |
| -from numpy import arccosh as acosh |
68 |
| -from numpy import arcsin as asin |
69 |
| -from numpy import arcsinh as asinh |
70 |
| -from numpy import arctan as atan |
71 |
| -from numpy import arctan2 as atan2 |
72 |
| -from numpy import arctanh as atanh |
73 |
| -from numpy import bool_ as bool |
74 |
| -from numpy import invert as bitwise_invert |
75 |
| -from numpy import left_shift as bitwise_left_shift |
76 |
| -from numpy import power as pow |
77 |
| -from numpy import right_shift as bitwise_right_shift |
| 1 | +import os |
| 2 | +from contextvars import ContextVar |
| 3 | +from enum import Enum |
78 | 4 |
|
79 |
| -from ._common import ( |
80 |
| - SparseArray, |
81 |
| - abs, |
82 |
| - all, |
83 |
| - any, |
84 |
| - asarray, |
85 |
| - asnumpy, |
86 |
| - astype, |
87 |
| - broadcast_arrays, |
88 |
| - broadcast_to, |
89 |
| - concat, |
90 |
| - concatenate, |
91 |
| - dot, |
92 |
| - einsum, |
93 |
| - empty, |
94 |
| - empty_like, |
95 |
| - equal, |
96 |
| - eye, |
97 |
| - full, |
98 |
| - full_like, |
99 |
| - isfinite, |
100 |
| - isinf, |
101 |
| - isnan, |
102 |
| - matmul, |
103 |
| - max, |
104 |
| - mean, |
105 |
| - min, |
106 |
| - moveaxis, |
107 |
| - nonzero, |
108 |
| - ones, |
109 |
| - ones_like, |
110 |
| - outer, |
111 |
| - pad, |
112 |
| - permute_dims, |
113 |
| - prod, |
114 |
| - reshape, |
115 |
| - round, |
116 |
| - squeeze, |
117 |
| - stack, |
118 |
| - std, |
119 |
| - sum, |
120 |
| - tensordot, |
121 |
| - var, |
122 |
| - vecdot, |
123 |
| - zeros, |
124 |
| - zeros_like, |
125 |
| -) |
126 |
| -from ._compressed import GCXS |
127 |
| -from ._coo import COO, as_coo |
128 |
| -from ._coo.common import ( |
129 |
| - argmax, |
130 |
| - argmin, |
131 |
| - argwhere, |
132 |
| - asCOO, |
133 |
| - clip, |
134 |
| - diagonal, |
135 |
| - diagonalize, |
136 |
| - expand_dims, |
137 |
| - flip, |
138 |
| - isneginf, |
139 |
| - isposinf, |
140 |
| - kron, |
141 |
| - matrix_transpose, |
142 |
| - nanmax, |
143 |
| - nanmean, |
144 |
| - nanmin, |
145 |
| - nanprod, |
146 |
| - nanreduce, |
147 |
| - nansum, |
148 |
| - result_type, |
149 |
| - roll, |
150 |
| - sort, |
151 |
| - take, |
152 |
| - tril, |
153 |
| - triu, |
154 |
| - unique_counts, |
155 |
| - unique_values, |
156 |
| - where, |
157 |
| -) |
158 |
| -from ._dok import DOK |
159 |
| -from ._io import load_npz, save_npz |
160 |
| -from ._umath import elemwise |
161 |
| -from ._utils import random |
162 | 5 | from ._version import __version__, __version_tuple__ # noqa: F401
|
163 | 6 |
|
164 |
| -__all__ = [ |
165 |
| - "COO", |
166 |
| - "DOK", |
167 |
| - "GCXS", |
168 |
| - "SparseArray", |
169 |
| - "abs", |
170 |
| - "acos", |
171 |
| - "acosh", |
172 |
| - "add", |
173 |
| - "all", |
174 |
| - "any", |
175 |
| - "argmax", |
176 |
| - "argmin", |
177 |
| - "argwhere", |
178 |
| - "asCOO", |
179 |
| - "as_coo", |
180 |
| - "asarray", |
181 |
| - "asin", |
182 |
| - "asinh", |
183 |
| - "asnumpy", |
184 |
| - "astype", |
185 |
| - "atan", |
186 |
| - "atan2", |
187 |
| - "atanh", |
188 |
| - "bitwise_and", |
189 |
| - "bitwise_invert", |
190 |
| - "bitwise_left_shift", |
191 |
| - "bitwise_not", |
192 |
| - "bitwise_or", |
193 |
| - "bitwise_right_shift", |
194 |
| - "bitwise_xor", |
195 |
| - "bool", |
196 |
| - "broadcast_arrays", |
197 |
| - "broadcast_to", |
198 |
| - "can_cast", |
199 |
| - "ceil", |
200 |
| - "clip", |
201 |
| - "complex128", |
202 |
| - "complex64", |
203 |
| - "concat", |
204 |
| - "concatenate", |
205 |
| - "cos", |
206 |
| - "cosh", |
207 |
| - "diagonal", |
208 |
| - "diagonalize", |
209 |
| - "divide", |
210 |
| - "dot", |
211 |
| - "e", |
212 |
| - "einsum", |
213 |
| - "elemwise", |
214 |
| - "empty", |
215 |
| - "empty_like", |
216 |
| - "equal", |
217 |
| - "exp", |
218 |
| - "expand_dims", |
219 |
| - "expm1", |
220 |
| - "eye", |
221 |
| - "finfo", |
222 |
| - "flip", |
223 |
| - "float16", |
224 |
| - "float32", |
225 |
| - "float64", |
226 |
| - "floor", |
227 |
| - "floor_divide", |
228 |
| - "full", |
229 |
| - "full_like", |
230 |
| - "greater", |
231 |
| - "greater_equal", |
232 |
| - "iinfo", |
233 |
| - "imag", |
234 |
| - "inf", |
235 |
| - "int16", |
236 |
| - "int32", |
237 |
| - "int64", |
238 |
| - "int8", |
239 |
| - "isfinite", |
240 |
| - "isinf", |
241 |
| - "isnan", |
242 |
| - "isneginf", |
243 |
| - "isposinf", |
244 |
| - "kron", |
245 |
| - "less", |
246 |
| - "less_equal", |
247 |
| - "load_npz", |
248 |
| - "log", |
249 |
| - "log10", |
250 |
| - "log1p", |
251 |
| - "log2", |
252 |
| - "logaddexp", |
253 |
| - "logical_and", |
254 |
| - "logical_not", |
255 |
| - "logical_or", |
256 |
| - "logical_xor", |
257 |
| - "matmul", |
258 |
| - "matrix_transpose", |
259 |
| - "max", |
260 |
| - "mean", |
261 |
| - "min", |
262 |
| - "moveaxis", |
263 |
| - "multiply", |
264 |
| - "nan", |
265 |
| - "nanmax", |
266 |
| - "nanmean", |
267 |
| - "nanmin", |
268 |
| - "nanprod", |
269 |
| - "nanreduce", |
270 |
| - "nansum", |
271 |
| - "negative", |
272 |
| - "newaxis", |
273 |
| - "nonzero", |
274 |
| - "not_equal", |
275 |
| - "ones", |
276 |
| - "ones_like", |
277 |
| - "outer", |
278 |
| - "pad", |
279 |
| - "permute_dims", |
280 |
| - "pi", |
281 |
| - "positive", |
282 |
| - "pow", |
283 |
| - "prod", |
284 |
| - "random", |
285 |
| - "real", |
286 |
| - "remainder", |
287 |
| - "reshape", |
288 |
| - "result_type", |
289 |
| - "roll", |
290 |
| - "round", |
291 |
| - "save_npz", |
292 |
| - "sign", |
293 |
| - "sin", |
294 |
| - "sinh", |
295 |
| - "sort", |
296 |
| - "sqrt", |
297 |
| - "square", |
298 |
| - "squeeze", |
299 |
| - "stack", |
300 |
| - "std", |
301 |
| - "subtract", |
302 |
| - "sum", |
303 |
| - "take", |
304 |
| - "tan", |
305 |
| - "tanh", |
306 |
| - "tensordot", |
307 |
| - "tril", |
308 |
| - "triu", |
309 |
| - "trunc", |
310 |
| - "uint16", |
311 |
| - "uint32", |
312 |
| - "uint64", |
313 |
| - "uint8", |
314 |
| - "unique_counts", |
315 |
| - "unique_values", |
316 |
| - "var", |
317 |
| - "vecdot", |
318 |
| - "where", |
319 |
| - "zeros", |
320 |
| - "zeros_like", |
321 |
| -] |
322 |
| - |
323 | 7 | __array_api_version__ = "2022.12"
|
| 8 | + |
| 9 | + |
| 10 | +class BackendType(Enum): |
| 11 | + PyData = "PyData" |
| 12 | + Finch = "Finch" |
| 13 | + |
| 14 | + |
| 15 | +_ENV_VAR_NAME = "SPARSE_BACKEND" |
| 16 | + |
| 17 | +backend_var = ContextVar("backend", default=BackendType.PyData) |
| 18 | + |
| 19 | +if _ENV_VAR_NAME in os.environ: |
| 20 | + backend_var.set(BackendType[os.environ[_ENV_VAR_NAME]]) |
| 21 | + |
| 22 | + |
| 23 | +class Backend: |
| 24 | + def __init__(self, backend=BackendType.PyData): |
| 25 | + self.backend = backend |
| 26 | + self.token = None |
| 27 | + |
| 28 | + def __enter__(self): |
| 29 | + token = backend_var.set(self.backend) |
| 30 | + self.token = token |
| 31 | + |
| 32 | + def __exit__(self, exc_type, exc_value, traceback): |
| 33 | + backend_var.reset(self.token) |
| 34 | + self.token = None |
| 35 | + |
| 36 | + @staticmethod |
| 37 | + def get_backend_module(): |
| 38 | + backend = backend_var.get() |
| 39 | + if backend == BackendType.PyData: |
| 40 | + import sparse.pydata_backend as backend_module |
| 41 | + elif backend == BackendType.Finch: |
| 42 | + import sparse.finch_backend as backend_module |
| 43 | + else: |
| 44 | + raise ValueError(f"Invalid backend identifier: {backend}") |
| 45 | + return backend_module |
| 46 | + |
| 47 | + |
| 48 | +def __getattr__(attr): |
| 49 | + if attr == "pydata_backend": |
| 50 | + import sparse.pydata_backend as backend_module |
| 51 | + |
| 52 | + return backend_module |
| 53 | + if attr == "finch_backend": |
| 54 | + import sparse.finch_backend as backend_module |
| 55 | + |
| 56 | + return backend_module |
| 57 | + |
| 58 | + return getattr(Backend.get_backend_module(), attr) |
0 commit comments