|
1 | 1 | import functools |
2 | 2 | from functools import partial |
3 | 3 |
|
| 4 | +import jax |
4 | 5 | from jax import tree_util |
5 | 6 | from jax.core import Primitive |
6 | 7 | from jax.interpreters import ad |
7 | 8 |
|
8 | 9 | __all__ = [ |
9 | | - 'defjvp', |
| 10 | + 'defjvp', |
10 | 11 | ] |
11 | 12 |
|
12 | 13 |
|
13 | 14 | def defjvp(primitive, *jvp_rules): |
14 | | - """Define JVP rules for any JAX primitive. |
| 15 | + """Define JVP rules for any JAX primitive. |
15 | 16 |
|
16 | | - This function is similar to ``jax.interpreters.ad.defjvp``. |
17 | | - However, the JAX one only supports primitive with ``multiple_results=False``. |
18 | | - ``brainpy.math.defjvp`` enables to define the independent JVP rule for |
19 | | - each input parameter no matter ``multiple_results=False/True``. |
| 17 | + This function is similar to ``jax.interpreters.ad.defjvp``. |
| 18 | + However, the JAX one only supports primitive with ``multiple_results=False``. |
| 19 | + ``brainpy.math.defjvp`` enables to define the independent JVP rule for |
| 20 | + each input parameter no matter ``multiple_results=False/True``. |
20 | 21 |
|
21 | | - For examples, please see ``test_ad_support.py``. |
| 22 | + For examples, please see ``test_ad_support.py``. |
22 | 23 |
|
23 | | - Args: |
24 | | - primitive: Primitive, XLACustomOp. |
25 | | - *jvp_rules: The JVP translation rule for each primal. |
26 | | - """ |
27 | | - assert isinstance(primitive, Primitive) |
28 | | - if primitive.multiple_results: |
29 | | - ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive) |
30 | | - else: |
31 | | - ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive) |
| 24 | + Args: |
| 25 | + primitive: Primitive, XLACustomOp. |
| 26 | + *jvp_rules: The JVP translation rule for each primal. |
| 27 | + """ |
| 28 | + assert isinstance(primitive, Primitive) |
| 29 | + if primitive.multiple_results: |
| 30 | + ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive) |
| 31 | + else: |
| 32 | + ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive) |
32 | 33 |
|
33 | 34 |
|
34 | 35 | def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params): |
35 | | - assert primitive.multiple_results |
36 | | - val_out = tuple(primitive.bind(*primals, **params)) |
37 | | - tree = tree_util.tree_structure(val_out) |
38 | | - tangents_out = [] |
39 | | - for rule, t in zip(jvp_rules, tangents): |
40 | | - if rule is not None and type(t) is not ad.Zero: |
41 | | - r = tuple(rule(t, *primals, **params)) |
42 | | - tangents_out.append(r) |
43 | | - assert tree_util.tree_structure(r) == tree |
44 | | - return val_out, functools.reduce(_add_tangents, |
45 | | - tangents_out, |
46 | | - tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) |
| 36 | + assert primitive.multiple_results |
| 37 | + val_out = tuple(primitive.bind(*primals, **params)) |
| 38 | + tree = tree_util.tree_structure(val_out) |
| 39 | + tangents_out = [] |
| 40 | + for rule, t in zip(jvp_rules, tangents): |
| 41 | + if rule is not None and type(t) is not ad.Zero: |
| 42 | + r = tuple(rule(t, *primals, **params)) |
| 43 | + tangents_out.append(r) |
| 44 | + assert tree_util.tree_structure(r) == tree |
| 45 | + return val_out, functools.reduce( |
| 46 | + _add_tangents, |
| 47 | + tangents_out, |
| 48 | + tree_util.tree_map( |
| 49 | + # compatible with JAX 0.4.34 |
| 50 | + lambda a: ad.Zero.from_primal_value(a) if jax.__version__ >= '0.4.34' else ad.Zero.from_value(a), |
| 51 | + val_out |
| 52 | + ) |
| 53 | + ) |
47 | 54 |
|
48 | 55 |
|
49 | 56 | def _add_tangents(xs, ys): |
50 | | - return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero)) |
51 | | - |
| 57 | + return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero)) |
0 commit comments