Skip to content

Commit 810a79d

Browse files
authored
0.3.1 (#84)
* adjust nbook * s to t * Initial work for symplectic integrators * odeint_symplectic added to API * hypersolver base * tutorial hypersolvers * new tutorial, better readme * better readme
1 parent 9cd2634 commit 810a79d

29 files changed

+2146
-291
lines changed

README.md

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
## A PyTorch based library for all things **neural differential equations**. Maintained by [DiffEqML](https://github.com/DiffEqML).
99

10-
![license](https://img.shields.io/badge/license-Apache%202.0-red.svg?)
10+
![license](https://img.shields.io/badge/license-Apache%202.0-blue.svg?)
1111
![CI](https://github.com/DiffEqML/torchdyn/actions/workflows/os-coverage.yml/badge.svg)
1212
[![Slack](https://img.shields.io/badge/slack-chat-blue.svg?logo=slack)](https://join.slack.com/t/diffeqml/shared_invite/zt-gq3jjj5x-LuHSB4m4gc9MsnvoF1UB6A)
1313
[![codecov](https://codecov.io/gh/DiffEqML/torchdyn/branch/master/graph/badge.svg)](https://codecov.io/gh/DiffEqML/torchdyn)
@@ -40,17 +40,12 @@ Contribute to the library with your benchmark and model variants! No need to rei
4040

4141
`pip install torchdyn`
4242

43-
* NOTE: temporarily requires additional manual installation of `torchsde`:
44-
45-
`pip install git+https://github.com/google-research/torchsde.git`
4643

4744
**Bleeding-edge** version:
4845

49-
`git clone https://github.com/DiffEqML/torchdyn.git`
50-
51-
`cd torchdyn`
46+
`git clone https://github.com/DiffEqML/torchdyn.git && cd torchdyn && python setup.py install`
5247

53-
`python setup.py install`
48+
Don't forget to install in your environment of choice if necessary. We offer an automated method for setting up your `torchdyn` environment designed specifically for contributors or those planning to tinker with the internals. Check `Contributing` below for more details.
5449

5550
## Documentation
5651
Check our [wiki](https://torchdyn.readthedocs.io/) for a full description of available features.
@@ -64,7 +59,7 @@ Interest in the blend of differential equations, deep learning and dynamical sys
6459
</p>
6560

6661

67-
By providing a centralized, easy-to-access collection of model templates, tutorial and application notebooks, we hope to speed-up research in this area and ultimately establish neural differential equations and implicit models into an effective tool for control, system identification and general machine learning tasks. W
62+
By providing a centralized, easy-to-access collection of model templates, tutorial and application notebooks, we hope to speed-up research in this area and ultimately establish neural differential equations and implicit models into an effective tool for control, system identification and general machine learning tasks.
6863

6964
#### Dependencies
7065
`torchdyn` leverages modern PyTorch best practices and handles training with `pytorch-lightning` [[6](https://github.com/PyTorchLightning/pytorch-lightning)]. We build Graph Neural ODEs utilizing the Graph Neural Networks (GNNs) API of `dgl` [[7](https://www.dgl.ai/)]. For a complete list of references, check `pyproject.toml`. We offer a complete suite of ODE solvers and sensitivity methods, extending the functionality offered by `torchdiffeq` [[1](https://arxiv.org/abs/1806.07366)]. We have light dependencies on `torchsde` [[7](https://arxiv.org/abs/2001.01328)] and `torchcde` [[8](https://arxiv.org/abs/2005.08926)].

benchmarks/numerics/bench_hypersolver.ipynb

Lines changed: 494 additions & 0 deletions
Large diffs are not rendered by default.

benchmarks/numerics/bench_odeint.ipynb

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,9 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 2,
5+
"execution_count": 1,
66
"metadata": {},
7-
"outputs": [
8-
{
9-
"ename": "ModuleNotFoundError",
10-
"evalue": "No module named 'torchdyn.numerics'",
11-
"output_type": "error",
12-
"traceback": [
13-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
14-
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
15-
"\u001b[0;32m<ipython-input-2-11c594d46672>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchdyn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumerics\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mEuler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRungeKutta4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTsitouras45\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDormandPrince45\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchdyn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumerics\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0modeint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
16-
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchdyn.numerics'"
17-
]
18-
}
19-
],
7+
"outputs": [],
208
"source": [
219
"import torchdyn\n",
2210
"import torch\n",
@@ -471,9 +459,9 @@
471459
],
472460
"metadata": {
473461
"kernelspec": {
474-
"display_name": "Python 3",
462+
"display_name": "torchdyn",
475463
"language": "python",
476-
"name": "python3"
464+
"name": "torchdyn"
477465
},
478466
"language_info": {
479467
"codemirror_mode": {

benchmarks/numerics/bench_symplectic.ipynb

Lines changed: 161 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ papermill = "*"
3030
poethepoet = "^0.10.0"
3131

3232
[tool.poe.tasks]
33-
force-cuda11 = "python -m pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html"
33+
force-cuda11 = "python -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html"
3434

3535
[build-system]
3636
build-backend = "poetry.masonry.api"

torchdyn/core/problems.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,24 @@
1111

1212

1313
class ODEProblem(nn.Module):
14-
def __init__(self, vector_field, solver:Union[str, nn.Module], atol:float=1e-4, rtol:float=1e-4, sensitivity='autograd',
14+
def __init__(self, vector_field, solver:Union[str, nn.Module], order:int=1, atol:float=1e-4, rtol:float=1e-4, sensitivity='autograd',
1515
solver_adjoint:Union[str, nn.Module, None] = None, atol_adjoint:float=1e-6, rtol_adjoint:float=1e-6, seminorm:bool=False):
16+
"""An ODE Problem coupling a given vector field with solver and sensitivity algorithm to compute gradients w.r.t different quantities.
17+
18+
Args:
19+
vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`.
20+
In the second case, the Callable is automatically wrapped for consistency
21+
solver (Union[str, nn.Module]): [description]
22+
order (int, optional): [description]. Defaults to 1.
23+
atol (float, optional): [description]. Defaults to 1e-4.
24+
rtol (float, optional): [description]. Defaults to 1e-4.
25+
sensitivity (str, optional): [description]. Defaults to 'autograd'.
26+
solver_adjoint (Union[str, nn.Module, None], optional): [description]. Defaults to None.
27+
atol_adjoint (float, optional): [description]. Defaults to 1e-6.
28+
rtol_adjoint (float, optional): [description]. Defaults to 1e-6.
29+
seminorm (bool, optional): Indicates whether the a seminorm should be used for error estimation during adjoint backsolves. Defaults to False.
30+
31+
"""
1632
super().__init__()
1733
# instantiate solver at initialization
1834
if type(solver) == str:
@@ -38,7 +54,7 @@ def __init__(self, vector_field, solver:Union[str, nn.Module], atol:float=1e-4,
3854
vector_field = DEFuncBase(vector_field, has_time_arg=False)
3955
else: vector_field = DEFuncBase(vector_field, has_time_arg=True)
4056

41-
self.vf, self.sensalg = vector_field, sensitivity
57+
self.vf, self.order, self.sensalg = vector_field, order, sensitivity
4258
if len(tuple(self.vf.parameters())) > 0:
4359
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
4460
else:
@@ -69,6 +85,16 @@ def forward(self, x:Tensor, t_span:Tensor):
6985

7086
class MultipleShootingProblem(nn.Module):
7187
def __init__(self, solver:str, vector_field, sensalg='autograd'):
88+
"""[summary]
89+
90+
Args:
91+
solver (str): [description]
92+
vector_field ([type]): [description]
93+
sensalg (str, optional): [description]. Defaults to 'autograd'.
94+
95+
Returns:
96+
[type]: [description]
97+
"""
7298
super().__init__()
7399
#
74100
self.solver
@@ -98,3 +124,4 @@ def forward(self, x0, t_span, atol=1e-4, rtol=1e-4, t_eval=[]):
98124
x0, t_span = prep_input(x0, t_span)
99125
t_eval, sol = self.odefunc(self.vf_params, x0, t_span, t_eval)
100126
return t_eval, sol
127+

torchdyn/models/README.md

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
11
### Goals of `torchdyn`
2-
Our aim with `torchdyn` aims is to provide a unified, flexible API to the most recent advances in continuous and implicit neural networks. Examples include neural differential equations variants, e.g.
2+
Our aim with `torchdyn` aims is to provide a unified, flexible API to aid in the implementation of recent advances in continuous and implicit learning. Some models already implemented, either here under `torchdyn.models` or in the tutorials, are:
33

44
* Neural Ordinary Differential Equations (Neural ODE) [[1](https://arxiv.org/abs/1806.07366)]
5-
* Neural Stochastic Differential Equations (Neural SDE) [[7](https://arxiv.org/abs/1905.09883),[8](https://arxiv.org/abs/1906.02355)]
6-
* Graph Neural ODEs [[9](https://arxiv.org/abs/1911.07532)]
7-
* Hamiltonian Neural Networks [[10](https://arxiv.org/abs/1906.01563)]
5+
* Galerkin Neural ODE [[2](https://arxiv.org/abs/2002.08071)]
6+
* Neural Stochastic Differential Equations (Neural SDE) [[3](https://arxiv.org/abs/1905.09883),[4](https://arxiv.org/abs/1906.02355)]
7+
* Graph Neural ODEs [[5](https://arxiv.org/abs/1911.07532)]
8+
* Hamiltonian Neural Networks [[6](https://arxiv.org/abs/1906.01563)]
89

9-
Depth-variant versions,
10-
* ANODEv2 [[11](https://arxiv.org/abs/1906.04596)]
11-
* Galerkin Neural ODE [[12](https://arxiv.org/abs/2002.08071)]
10+
Recurrent or "hybrid" versions for sequences
11+
* ODE-RNN [[7](https://arxiv.org/abs/1907.03907)]
1212

13-
Recurrent or "hybrid" versions
14-
* ODE-RNN [[13](https://arxiv.org/abs/1907.03907)]
15-
* GRU-ODE-Bayes [[14](https://arxiv.org/abs/1905.12374)]
13+
Neural numerical methods
14+
* Hypersolvers [[12](https://arxiv.org/pdf/2007.09601.pdf)]
1615

1716
Augmentation strategies to relieve neural differential equations of their expressivity limitations and reduce the computational burden of the numerical solver
18-
* ANODE (0-augmentation) [[15](https://arxiv.org/abs/1904.01681)]
19-
* Input-layer augmentation [[16](https://arxiv.org/abs/2002.08071)]
20-
* Higher-order augmentation [[17](https://arxiv.org/abs/2002.08071)]
17+
* ANODE (0-augmentation) [[8](https://arxiv.org/abs/1904.01681)]
18+
* Input-layer augmentation [[9](https://arxiv.org/abs/2002.08071)]
19+
* Higher-order augmentation [[10](https://arxiv.org/abs/2002.08071)]
2120

22-
Alternative or modified adjoint training techniques
23-
* Integral loss adjoint [[18](https://arxiv.org/abs/2003.08063)]
24-
* Checkpointed adjoint [[19](https://arxiv.org/abs/1902.10298)]
21+
Various sensitivity algorithms / variants
22+
* Integral loss adjoint [[11](https://arxiv.org/abs/2003.08063)]

torchdyn/models/hybrid.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
import math
12
import torch
23
import torch.nn as nn
4+
from torch.distributions import Normal, kl_divergence
5+
import pytorch_lightning as pl
6+
import torchsde
7+
8+
from torchdyn.models import LSDEFunc
39

410

511
class HybridNeuralDE(nn.Module):
@@ -42,16 +48,6 @@ def _jump_latent_cell(self, *args):
4248
return self.jump(x_t, (h, c))
4349

4450

45-
46-
import math
47-
import torch
48-
import torch.nn as nn
49-
from torch.distributions import Normal, kl_divergence
50-
import pytorch_lightning as pl
51-
import torchsde
52-
53-
from torchdyn.models import LSDEFunc
54-
5551
class LatentNeuralSDE(NeuralSDE, pl.LightningModule): # pragma: no cover
5652
def __init__(self, post_drift, diffusion, prior_drift, sigma, theta, mu, options,
5753
noise_type, order, sensitivity, s_span, solver, atol, rtol, intloss):
@@ -74,9 +70,14 @@ def __init__(self, post_drift, diffusion, prior_drift, sigma, theta, mu, options
7470
self.qy0_logvar = nn.Parameter(torch.tensor([[logvar]]), requires_grad=True)
7571

7672
def forward(self, eps: torch.Tensor, s_span=None):
77-
"""
78-
:param s_span: (optional) Series span -- can pass extended version for additional regularization
79-
:param eps: Noise sample
73+
"""[summary]
74+
75+
Args:
76+
eps (torch.Tensor): [description]
77+
s_span ([type], optional): [description]. Defaults to None.
78+
79+
Returns:
80+
[type]: [description]
8081
"""
8182

8283
eps = eps.to(self.qy0_std)
@@ -86,15 +87,12 @@ def forward(self, eps: torch.Tensor, s_span=None):
8687
py0 = Normal(loc=self.py0_mean, scale=self.py0_std)
8788
logqp0 = kl_divergence(qy0, py0).sum(1).mean(0) # KL(time=0).
8889

89-
# Expand s_span to penalize out-of-datasets region and spread uncertainty -- moved
90-
# s_span_ext = torch.cat((torch.tensor([0.0]), self.s_span.to('cpu'), torch.tensor([2.0])))
91-
9290
if s_span is not None:
9391
s_span_ext = s_span
9492
else:
9593
s_span_ext = self.s_span.cpu()
9694

97-
zs, logqp = sdeint(sde=self.defunc, x0=x0, s_span=s_span_ext,
95+
zs, logqp = torchsde.sdeint(sde=self.defunc, x0=x0, s_span=s_span_ext,
9896
rtol=self.rtol, atol=self.atol, logqp=True, options=self.options,
9997
adaptive=self.adaptive, method=self.solver)
10098

@@ -104,24 +102,34 @@ def forward(self, eps: torch.Tensor, s_span=None):
104102
return zs, log_ratio
105103

106104
def sample_p(self, vis_span, n_sim, eps=None, bm=None, dt=0.01):
107-
"""
108-
:param vis_span:
109-
:param n_sim:
110-
:param eps:
111-
:param bm:
112-
:param dt:
105+
"""[summary]
106+
107+
Args:
108+
vis_span ([type]): [description]
109+
n_sim ([type]): [description]
110+
eps ([type], optional): [description]. Defaults to None.
111+
bm ([type], optional): [description]. Defaults to None.
112+
dt (float, optional): [description]. Defaults to 0.01.
113+
114+
Returns:
115+
[type]: [description]
113116
"""
114117
eps = torch.randn(n_sim, 1).to(self.py0_mean).to(self.device) if eps is None else eps
115118
y0 = self.py0_mean + eps.to(self.device) * self.py0_std
116119
return torchsde.sdeint(self.defunc, y0, vis_span, bm=bm, method='srk', dt=dt, names={'drift': 'h'})
117120

118121
def sample_q(self, vis_span, n_sim, eps=None, bm=None, dt=0.01):
119-
"""
120-
:param vis_span:
121-
:param n_sim:
122-
:param eps:
123-
:param bm:
124-
:param dt:
122+
"""[summary]
123+
124+
Args:
125+
vis_span ([type]): [description]
126+
n_sim ([type]): [description]
127+
eps ([type], optional): [description]. Defaults to None.
128+
bm ([type], optional): [description]. Defaults to None.
129+
dt (float, optional): [description]. Defaults to 0.01.
130+
131+
Returns:
132+
[type]: [description]
125133
"""
126134
eps = torch.randn(n_sim, 1).to(self.qy0_mean) if eps is None else eps
127135
y0 = self.qy0_mean + eps.to(self.device) * self.qy0_std

torchdyn/nn/galerkin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ def __init__(self, bias=True, expfunc=Fourier(5), dilation=True, shift=True):
174174
def reset_parameters(self):
175175
torch.nn.init.zeros_(self.coeffs)
176176

177-
def calculate_weights(self, s):
178-
"Expands `s` following the chosen eigenbasis"
177+
def calculate_weights(self, t):
178+
"Expands `t` following the chosen eigenbasis"
179179
n_range = torch.linspace(0, self.deg, self.deg).to(self.coeffs.device)
180-
basis = self.expfunc(n_range, s*self.dilation.to(self.coeffs.device) + self.shift.to(self.coeffs.device))
180+
basis = self.expfunc(n_range, t*self.dilation.to(self.coeffs.device) + self.shift.to(self.coeffs.device))
181181
B = []
182182
for i in range(self.n_eig):
183183
Bin = torch.eye(self.deg).to(self.coeffs.device)

torchdyn/numerics/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13-
from torchdyn.numerics.solvers import Euler, RungeKutta4, Tsitouras45, DormandPrince45
14-
from torchdyn.numerics.odeint import odeint
13+
from torchdyn.numerics.solvers import Euler, RungeKutta4, Tsitouras45, DormandPrince45, AsynchronousLeapfrog
14+
from torchdyn.numerics.hypersolvers import HyperEuler
15+
from torchdyn.numerics.odeint import odeint, odeint_symplectic
1516

1617

17-
__all__ = ['odeint', 'Euler', 'RungeKutta4', 'DormandPrince45', 'Tsitouras45']
18+
__all__ = ['odeint', 'odeint_symplectic', 'Euler', 'RungeKutta4', 'DormandPrince45', 'Tsitouras45',
19+
'AsynchronousLeapfrog', 'HyperEuler']
20+

torchdyn/numerics/hypersolvers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
import torch.nn as nn
3+
from torchdyn.numerics.solvers import Euler, RungeKutta4
4+
5+
class HyperEuler(Euler):
6+
def __init__(self, hypernet, dtype=torch.float32):
7+
super().__init__(dtype)
8+
self.hypernet = hypernet
9+
self.stepping_class = 'fixed'
10+
11+
def step(self, f, x, t, dt, k1=None):
12+
_, _, x_sol = super().step(f, x, t, dt, k1)
13+
return None, None, x_sol + dt**2 * self.hypernet(t, x)
14+
15+
class HyperRungeKutta4(RungeKutta4):
16+
def __init__(self, hypernet, dtype=torch.float32):
17+
super().__init__(dtype)
18+
self.hypernet = hypernet
19+
20+
def step(self, f, x, t, dt, k1=None):
21+
_, _, x_sol = super().step(f, x, t, dt, k1)
22+
return None, None, x_sol + dt**5 * self.hypernet(t, x)

0 commit comments

Comments
 (0)