Skip to content

Commit e539c21

Browse files
committed
fix ad compatibility
1 parent 2a5adea commit e539c21

File tree

1 file changed

+36
-30
lines changed

1 file changed

+36
-30
lines changed
Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,57 @@
11
import functools
22
from functools import partial
33

4+
import jax
45
from jax import tree_util
56
from jax.core import Primitive
67
from jax.interpreters import ad
78

89
__all__ = [
9-
'defjvp',
10+
'defjvp',
1011
]
1112

1213

1314
def defjvp(primitive, *jvp_rules):
14-
"""Define JVP rules for any JAX primitive.
15+
"""Define JVP rules for any JAX primitive.
1516
16-
This function is similar to ``jax.interpreters.ad.defjvp``.
17-
However, the JAX one only supports primitive with ``multiple_results=False``.
18-
``brainpy.math.defjvp`` enables to define the independent JVP rule for
19-
each input parameter no matter ``multiple_results=False/True``.
17+
This function is similar to ``jax.interpreters.ad.defjvp``.
18+
However, the JAX one only supports primitive with ``multiple_results=False``.
19+
``brainpy.math.defjvp`` enables to define the independent JVP rule for
20+
each input parameter no matter ``multiple_results=False/True``.
2021
21-
For examples, please see ``test_ad_support.py``.
22+
For examples, please see ``test_ad_support.py``.
2223
23-
Args:
24-
primitive: Primitive, XLACustomOp.
25-
*jvp_rules: The JVP translation rule for each primal.
26-
"""
27-
assert isinstance(primitive, Primitive)
28-
if primitive.multiple_results:
29-
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
30-
else:
31-
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
24+
Args:
25+
primitive: Primitive, XLACustomOp.
26+
*jvp_rules: The JVP translation rule for each primal.
27+
"""
28+
assert isinstance(primitive, Primitive)
29+
if primitive.multiple_results:
30+
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
31+
else:
32+
ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
3233

3334

3435
def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
35-
assert primitive.multiple_results
36-
val_out = tuple(primitive.bind(*primals, **params))
37-
tree = tree_util.tree_structure(val_out)
38-
tangents_out = []
39-
for rule, t in zip(jvp_rules, tangents):
40-
if rule is not None and type(t) is not ad.Zero:
41-
r = tuple(rule(t, *primals, **params))
42-
tangents_out.append(r)
43-
assert tree_util.tree_structure(r) == tree
44-
return val_out, functools.reduce(_add_tangents,
45-
tangents_out,
46-
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
36+
assert primitive.multiple_results
37+
val_out = tuple(primitive.bind(*primals, **params))
38+
tree = tree_util.tree_structure(val_out)
39+
tangents_out = []
40+
for rule, t in zip(jvp_rules, tangents):
41+
if rule is not None and type(t) is not ad.Zero:
42+
r = tuple(rule(t, *primals, **params))
43+
tangents_out.append(r)
44+
assert tree_util.tree_structure(r) == tree
45+
return val_out, functools.reduce(
46+
_add_tangents,
47+
tangents_out,
48+
tree_util.tree_map(
49+
# compatible with JAX 0.4.34
50+
lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a),
51+
val_out
52+
)
53+
)
4754

4855

4956
def _add_tangents(xs, ys):
50-
return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
51-
57+
return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))

0 commit comments

Comments
 (0)