Skip to content

Commit 0015964

Browse files
committed
Update JAX setup for READTHEDOCS and travis
1 parent 768b022 commit 0015964

File tree

4 files changed

+17
-5
lines changed

4 files changed

+17
-5
lines changed

.travis.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ install:
2828

2929
- conda install -c conda-forge scikit-optimize
3030
- pip install tensorflow tensorflow-probability
31-
- conda install pytorch torchvision torchaudio cpuonly -c pytorch
31+
- conda install pytorch cpuonly -c pytorch
32+
- pip install jax flax optax
3233

3334
script:
3435
# Your test script goes here
3536
- DDEBACKEND=tensorflow.compat.v1 python -c "import deepxde"
3637
- DDEBACKEND=tensorflow python -c "import deepxde"
3738
- DDEBACKEND=pytorch python -c "import deepxde"
39+
- DDEBACKEND=jax python -c "import deepxde"

deepxde/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
if os.environ.get("READTHEDOCS") == "True":
2424
# The backend should be tensorflow/tensorflow.compat.v1 to ensure backend.tf is not
2525
# None.
26+
from . import jax
2627
from . import pytorch
2728
from . import tensorflow
2829
from . import tensorflow_compat_v1

docs/index.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ If you are looking for information on a specific function, class or method, this
8282
modules/deepxde
8383
modules/deepxde.data
8484
modules/deepxde.geometry
85-
modules/deepxde.icbcs
85+
modules/deepxde.icbc
8686
modules/deepxde.nn
87-
modules/deepxde.nn.tensorflow_compat_v1
88-
modules/deepxde.nn.tensorflow
87+
modules/deepxde.nn.jax
8988
modules/deepxde.nn.pytorch
89+
modules/deepxde.nn.tensorflow
90+
modules/deepxde.nn.tensorflow_compat_v1
9091
modules/deepxde.optimizers
9192
modules/deepxde.utils
9293

docs/requirements.txt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@ numpy
33
scikit-learn
44
scikit-optimize
55
scipy
6+
docutils<0.18 # https://github.com/readthedocs/readthedocs.org/issues/8616
7+
8+
# TensorFlow 1.x
69
tensorflow>=2.2.0
10+
# TensorFlow 2.x
711
tensorflow-probability>=0.10.0
12+
# PyTorch
813
torch
9-
docutils<0.18 # https://github.com/readthedocs/readthedocs.org/issues/8616
14+
# JAX
15+
jax
16+
flax
17+
optax

0 commit comments

Comments
 (0)