|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | # ============================================================================== |
16 | | -import warnings |
17 | | -from functools import partial |
| 16 | + |
18 | 17 | from typing import Tuple |
19 | 18 |
|
20 | | -import jax |
21 | | -import numpy as np |
22 | | -from jax import core, numpy as jnp |
23 | | -from jax import lax |
24 | | -from jax.interpreters import batching |
25 | | -from jax.interpreters import mlir, ad |
26 | | -from jax.tree_util import tree_flatten, tree_unflatten |
27 | | -from jaxlib import gpu_sparse |
| 19 | +from jax import numpy as jnp |
| 20 | +from jax.experimental.sparse import csr_todense |
28 | 21 |
|
29 | 22 | from brainpy.math.interoperability import as_jax |
30 | 23 |
|
31 | | -if jax.__version__ >= '0.5.0': |
32 | | - from jax.extend.core import Primitive |
33 | | -else: |
34 | | - from jax.core import Primitive |
35 | | - |
36 | 24 | __all__ = [ |
37 | 25 | 'coo_to_csr', |
38 | 26 | 'csr_to_coo', |
39 | 27 | 'csr_to_dense' |
40 | 28 | ] |
41 | 29 |
|
42 | 30 |
|
43 | | -def _general_batching_rule(prim, args, axes, **kwargs): |
44 | | - batch_axes, batch_args, non_batch_args = [], {}, {} |
45 | | - for ax_i, ax in enumerate(axes): |
46 | | - if ax is None: |
47 | | - non_batch_args[f'ax{ax_i}'] = args[ax_i] |
48 | | - else: |
49 | | - batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0) |
50 | | - batch_axes.append(ax_i) |
51 | | - |
52 | | - def f(_, x): |
53 | | - pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}']) |
54 | | - for i in range(len(axes))]) |
55 | | - return 0, prim.bind(*pars, **kwargs) |
56 | | - |
57 | | - _, outs = lax.scan(f, 0, batch_args) |
58 | | - out_vals, out_tree = tree_flatten(outs) |
59 | | - out_dim = tree_unflatten(out_tree, (0,) * len(out_vals)) |
60 | | - return outs, out_dim |
61 | | - |
62 | | - |
63 | | -def _register_general_batching(prim): |
64 | | - batching.primitive_batchers[prim] = partial(_general_batching_rule, prim) |
65 | | - |
66 | | - |
67 | 31 | def coo_to_csr( |
68 | 32 | pre_ids: jnp.ndarray, |
69 | 33 | post_ids: jnp.ndarray, |
@@ -97,108 +61,6 @@ def csr_to_coo( |
97 | 61 | return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices |
98 | 62 |
|
99 | 63 |
|
100 | | -def csr_to_csc(): |
101 | | - pass |
102 | | - |
103 | | - |
104 | | -def coo_to_dense( |
105 | | - data: jnp.ndarray, |
106 | | - rows: jnp.ndarray, |
107 | | - cols: jnp.ndarray, |
108 | | - *, |
109 | | - shape: Tuple[int, int] |
110 | | -) -> jnp.ndarray: |
111 | | - pass |
112 | | - |
113 | | - |
114 | | -def csr_to_dense( |
115 | | - data: jnp.ndarray, |
116 | | - indices: jnp.ndarray, |
117 | | - indptr: jnp.ndarray, |
118 | | - *, |
119 | | - shape: Tuple[int, int] |
120 | | -) -> jnp.ndarray: |
121 | | - data = as_jax(data) |
122 | | - indices = as_jax(indices) |
123 | | - indptr = as_jax(indptr) |
124 | | - return csr_to_dense_p.bind(data, indices, indptr, shape=shape) |
125 | | - |
126 | | - |
127 | | -def _coo_extract(row, col, mat): |
128 | | - """Extract values of dense matrix mat at given COO indices.""" |
129 | | - return mat[row, col] |
130 | | - |
131 | | - |
132 | | -def _csr_extract(indices, indptr, mat): |
133 | | - """Extract values of dense matrix mat at given CSR indices.""" |
134 | | - return _coo_extract(*csr_to_coo(indices, indptr), mat) |
135 | | - |
136 | | - |
137 | | -def _coo_todense(data, row, col, *, shape): |
138 | | - """Convert CSR-format sparse matrix to a dense matrix. |
139 | | -
|
140 | | - Args: |
141 | | - data : array of shape ``(nse,)``. |
142 | | - row : array of shape ``(nse,)`` |
143 | | - col : array of shape ``(nse,)`` and dtype ``row.dtype`` |
144 | | - shape : COOInfo object containing matrix metadata |
145 | | -
|
146 | | - Returns: |
147 | | - mat : array with specified shape and dtype matching ``data`` |
148 | | - """ |
149 | | - return jnp.zeros(shape, data.dtype).at[row, col].add(data) |
150 | | - |
151 | | - |
152 | | -def _csr_to_dense_impl(data, indices, indptr, *, shape): |
153 | | - return _coo_todense(data, *csr_to_coo(indices, indptr), shape=shape) |
154 | | - |
155 | | - |
156 | | -def _csr_to_dense_abstract_eval(data, indices, indptr, *, shape): |
157 | | - assert data.ndim == indices.ndim == indptr.ndim == 1 |
158 | | - assert indices.dtype == indptr.dtype |
159 | | - assert data.shape == indices.shape |
160 | | - assert indptr.shape[0] == shape[0] + 1 |
161 | | - return core.ShapedArray(shape, data.dtype) |
162 | | - |
163 | | - |
164 | | -_csr_to_dense_lowering = mlir.lower_fun(_csr_to_dense_impl, multiple_results=False) |
165 | | - |
166 | | - |
167 | | -def _csr_to_dense_gpu_lowering(ctx, data, indices, indptr, *, shape): |
168 | | - data_aval, indices_aval, _ = ctx.avals_in |
169 | | - dtype = data_aval.dtype |
170 | | - if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): |
171 | | - warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. " |
172 | | - "Falling back to default implementation.", |
173 | | - UserWarning) |
174 | | - return _csr_to_dense_lowering(ctx, data, indices, indptr, shape=shape) |
175 | | - return [gpu_sparse.cuda_csr_todense(data, indices, indptr, |
176 | | - shape=shape, data_dtype=dtype, |
177 | | - index_dtype=indices_aval.dtype)] |
178 | | - |
179 | | - |
180 | | -def _csr_to_dense_jvp(data_dot, data, indices, indptr, *, shape): |
181 | | - return csr_to_dense(data_dot, indices, indptr, shape=shape) |
182 | | - |
183 | | - |
184 | | -def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape): |
185 | | - # Note: we assume that transpose has the same sparsity pattern. |
186 | | - # Can we check this? |
187 | | - assert ad.is_undefined_primal(data) |
188 | | - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): |
189 | | - raise ValueError("Cannot transpose with respect to sparse indices") |
190 | | - assert ct.shape == shape |
191 | | - assert indices.aval.dtype == indptr.aval.dtype |
192 | | - assert ct.dtype == data.aval.dtype |
193 | | - return _csr_extract(indices, indptr, ct), indices, indptr |
| 64 | +csr_to_dense = csr_todense |
194 | 65 |
|
195 | 66 |
|
196 | | -csr_to_dense_p = Primitive('csr_to_dense') |
197 | | -csr_to_dense_p.def_impl(_csr_to_dense_impl) |
198 | | -csr_to_dense_p.def_abstract_eval(_csr_to_dense_abstract_eval) |
199 | | -ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None) |
200 | | -ad.primitive_transposes[csr_to_dense_p] = _csr_to_dense_transpose |
201 | | -mlir.register_lowering(csr_to_dense_p, _csr_to_dense_lowering) |
202 | | -_register_general_batching(csr_to_dense_p) |
203 | | -if gpu_sparse.cuda_is_supported: |
204 | | - mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda') |
0 commit comments