Skip to content

Commit da2f4ac

Browse files
committed
compatible with jax>=0.8.0
1 parent a3569aa commit da2f4ac

File tree

5 files changed

+12
-154
lines changed

5 files changed

+12
-154
lines changed

brainpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
__version__ = "2.7.1"
18-
__version_info__ = (2, 7, 1)
17+
__version__ = "2.7.2"
18+
__version_info__ = (2, 7, 2)
1919

2020

2121
from brainpy import _errors as errors

brainpy/math/object_transform/base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,10 @@
3030
from brainpy.math.defaults import defaults
3131
from brainpy.math.modes import Mode
3232
from brainpy.math.ndarray import (Array, )
33-
from brainpy.math.object_transform.collectors import (ArrayCollector, Collector)
34-
from brainpy.math.object_transform.naming import (
35-
get_unique_name,
36-
check_name_uniqueness
37-
)
33+
from brainpy.math.object_transform.collectors import ArrayCollector, Collector
34+
from brainpy.math.object_transform.naming import get_unique_name, check_name_uniqueness
3835
from brainpy.math.object_transform.variables import (
39-
Variable, VariableView, TrainVar,
40-
VarList, VarDict
36+
Variable, VariableView, TrainVar, VarList, VarDict
4137
)
4238
from brainpy.math.sharding import BATCH_AXIS
4339

brainpy/math/object_transform/variables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414
# ==============================================================================
1515
from typing import Optional, Any, Sequence
1616

17+
import brainstate
1718
import jax
1819
import numpy as np
20+
from brainstate._state import record_state_value_read, record_state_value_write
1921
from jax import numpy as jnp
2022
from jax.dtypes import canonicalize_dtype
2123
from jax.tree_util import register_pytree_node_class
2224

23-
import brainstate
2425
from brainpy._errors import MathError
2526
from brainpy.math.ndarray import Array
2627
from brainpy.math.sharding import BATCH_AXIS
27-
from brainstate._state import record_state_value_read, record_state_value_write
2828

2929
__all__ = [
3030
'Variable',

brainpy/math/sparse/utils.py

Lines changed: 4 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -13,57 +13,21 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
# ==============================================================================
16-
import warnings
17-
from functools import partial
16+
1817
from typing import Tuple
1918

20-
import jax
21-
import numpy as np
22-
from jax import core, numpy as jnp
23-
from jax import lax
24-
from jax.interpreters import batching
25-
from jax.interpreters import mlir, ad
26-
from jax.tree_util import tree_flatten, tree_unflatten
27-
from jaxlib import gpu_sparse
19+
from jax import numpy as jnp
20+
from jax.experimental.sparse import csr_todense
2821

2922
from brainpy.math.interoperability import as_jax
3023

31-
if jax.__version__ >= '0.5.0':
32-
from jax.extend.core import Primitive
33-
else:
34-
from jax.core import Primitive
35-
3624
__all__ = [
3725
'coo_to_csr',
3826
'csr_to_coo',
3927
'csr_to_dense'
4028
]
4129

4230

43-
def _general_batching_rule(prim, args, axes, **kwargs):
44-
batch_axes, batch_args, non_batch_args = [], {}, {}
45-
for ax_i, ax in enumerate(axes):
46-
if ax is None:
47-
non_batch_args[f'ax{ax_i}'] = args[ax_i]
48-
else:
49-
batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0)
50-
batch_axes.append(ax_i)
51-
52-
def f(_, x):
53-
pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
54-
for i in range(len(axes))])
55-
return 0, prim.bind(*pars, **kwargs)
56-
57-
_, outs = lax.scan(f, 0, batch_args)
58-
out_vals, out_tree = tree_flatten(outs)
59-
out_dim = tree_unflatten(out_tree, (0,) * len(out_vals))
60-
return outs, out_dim
61-
62-
63-
def _register_general_batching(prim):
64-
batching.primitive_batchers[prim] = partial(_general_batching_rule, prim)
65-
66-
6731
def coo_to_csr(
6832
pre_ids: jnp.ndarray,
6933
post_ids: jnp.ndarray,
@@ -97,108 +61,6 @@ def csr_to_coo(
9761
return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices
9862

9963

100-
def csr_to_csc():
101-
pass
102-
103-
104-
def coo_to_dense(
105-
data: jnp.ndarray,
106-
rows: jnp.ndarray,
107-
cols: jnp.ndarray,
108-
*,
109-
shape: Tuple[int, int]
110-
) -> jnp.ndarray:
111-
pass
112-
113-
114-
def csr_to_dense(
115-
data: jnp.ndarray,
116-
indices: jnp.ndarray,
117-
indptr: jnp.ndarray,
118-
*,
119-
shape: Tuple[int, int]
120-
) -> jnp.ndarray:
121-
data = as_jax(data)
122-
indices = as_jax(indices)
123-
indptr = as_jax(indptr)
124-
return csr_to_dense_p.bind(data, indices, indptr, shape=shape)
125-
126-
127-
def _coo_extract(row, col, mat):
128-
"""Extract values of dense matrix mat at given COO indices."""
129-
return mat[row, col]
130-
131-
132-
def _csr_extract(indices, indptr, mat):
133-
"""Extract values of dense matrix mat at given CSR indices."""
134-
return _coo_extract(*csr_to_coo(indices, indptr), mat)
135-
136-
137-
def _coo_todense(data, row, col, *, shape):
138-
"""Convert CSR-format sparse matrix to a dense matrix.
139-
140-
Args:
141-
data : array of shape ``(nse,)``.
142-
row : array of shape ``(nse,)``
143-
col : array of shape ``(nse,)`` and dtype ``row.dtype``
144-
shape : COOInfo object containing matrix metadata
145-
146-
Returns:
147-
mat : array with specified shape and dtype matching ``data``
148-
"""
149-
return jnp.zeros(shape, data.dtype).at[row, col].add(data)
150-
151-
152-
def _csr_to_dense_impl(data, indices, indptr, *, shape):
153-
return _coo_todense(data, *csr_to_coo(indices, indptr), shape=shape)
154-
155-
156-
def _csr_to_dense_abstract_eval(data, indices, indptr, *, shape):
157-
assert data.ndim == indices.ndim == indptr.ndim == 1
158-
assert indices.dtype == indptr.dtype
159-
assert data.shape == indices.shape
160-
assert indptr.shape[0] == shape[0] + 1
161-
return core.ShapedArray(shape, data.dtype)
162-
163-
164-
_csr_to_dense_lowering = mlir.lower_fun(_csr_to_dense_impl, multiple_results=False)
165-
166-
167-
def _csr_to_dense_gpu_lowering(ctx, data, indices, indptr, *, shape):
168-
data_aval, indices_aval, _ = ctx.avals_in
169-
dtype = data_aval.dtype
170-
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
171-
warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
172-
"Falling back to default implementation.",
173-
UserWarning)
174-
return _csr_to_dense_lowering(ctx, data, indices, indptr, shape=shape)
175-
return [gpu_sparse.cuda_csr_todense(data, indices, indptr,
176-
shape=shape, data_dtype=dtype,
177-
index_dtype=indices_aval.dtype)]
178-
179-
180-
def _csr_to_dense_jvp(data_dot, data, indices, indptr, *, shape):
181-
return csr_to_dense(data_dot, indices, indptr, shape=shape)
182-
183-
184-
def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape):
185-
# Note: we assume that transpose has the same sparsity pattern.
186-
# Can we check this?
187-
assert ad.is_undefined_primal(data)
188-
if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
189-
raise ValueError("Cannot transpose with respect to sparse indices")
190-
assert ct.shape == shape
191-
assert indices.aval.dtype == indptr.aval.dtype
192-
assert ct.dtype == data.aval.dtype
193-
return _csr_extract(indices, indptr, ct), indices, indptr
64+
csr_to_dense = csr_todense
19465

19566

196-
csr_to_dense_p = Primitive('csr_to_dense')
197-
csr_to_dense_p.def_impl(_csr_to_dense_impl)
198-
csr_to_dense_p.def_abstract_eval(_csr_to_dense_abstract_eval)
199-
ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None)
200-
ad.primitive_transposes[csr_to_dense_p] = _csr_to_dense_transpose
201-
mlir.register_lowering(csr_to_dense_p, _csr_to_dense_lowering)
202-
_register_general_batching(csr_to_dense_p)
203-
if gpu_sparse.cuda_is_supported:
204-
mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda')

docs_state/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ Learn more
117117
.. grid-item::
118118
:columns: 6 6 6 4
119119

120-
.. card:: :material-regular:`data_exploration;2em` Classical APIs
120+
.. card:: :material-regular:`data_exploration;2em` ``brainpy`` APIs
121121
:class-card: sd-text-black sd-bg-light
122122
:link: https://brainpy.readthedocs.io/
123123

0 commit comments

Comments
 (0)