Skip to content

Commit 7e4575c

Browse files
committed
Merge branch 'fix-ad-support' into braintaichi-op
2 parents dc63758 + 2f9952f commit 7e4575c

File tree

4 files changed

+15
-2
lines changed

4 files changed

+15
-2
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ jobs:
7979
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
8080
pip uninstall brainpy -y
8181
python setup.py install
82+
pip install jax==0.4.30
83+
pip install jaxlib==0.4.30
8284
- name: Test with pytest
8385
run: |
8486
cd brainpy

brainpy/_src/dnn/tests/test_activation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import brainpy.math as bm
55

66

7+
78
class Test_Activation(parameterized.TestCase):
89

910
@parameterized.product(

brainpy/_src/dnn/tests/test_conv_layers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
# -*- coding: utf-8 -*-
2+
import platform
23

34
import jax.numpy as jnp
5+
import pytest
46
from absl.testing import absltest
57
from absl.testing import parameterized
68

79
import brainpy as bp
810
import brainpy.math as bm
911

12+
if platform.system() == 'Darwin':
13+
pytest.skip('skip Mac OS', allow_module_level=True)
14+
1015

1116
class TestConv(parameterized.TestCase):
1217
def test_Conv2D_img(self):

brainpy/_src/math/op_register/ad_support.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,14 @@ def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
4141
r = tuple(rule(t, *primals, **params))
4242
tangents_out.append(r)
4343
assert tree_util.tree_structure(r) == tree
44-
return val_out, functools.reduce(_add_tangents,
44+
try:
45+
return val_out, functools.reduce(_add_tangents,
4546
tangents_out,
46-
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
47+
tree_util.tree_map(lambda a: ad.Zero.from_primal_value(a), val_out))
48+
except:
49+
return val_out, functools.reduce(_add_tangents,
50+
tangents_out,
51+
tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
4752

4853

4954
def _add_tangents(xs, ys):

0 commit comments

Comments
 (0)