Implementation of neural Ordinary Differential Equations (ODE) built for [deeplearning4j](https://deeplearning4j.org/).
[[Arxiv](https://arxiv.org/abs/1806.07366)]
[[Pytorch repo by paper authors](https://github.com/rtqichen/torchdiffeq)]
NOTE: This is very much a work in progress, and since I haven't touched a differential equation since school, chances are there are conceptual misunderstandings.
The performance of the MNIST example is in line with the results presented in the paper, but given the simplicity of that dataset this is no guarantee of a correct implementation.
[[Very good blog post](https://julialang.org/blog/2019/01/fluxdiffeq)]
## Getting Started
Clone the repository and build with Maven:

```
cd neuralODE4j
mvn install
```
I will try to create a Maven artifact whenever I find the time for it. Please file an issue if you are interested.
Implementations of the MNIST and spiral generation toy experiments from the paper can be found under examples [[link]](./src/main/java/examples)
## Usage
The class [OdeVertex](./src/main/java/ode/vertex/conf/OdeVertex.java) is used to add an arbitrary graph of Layers or GraphVertices as an ODE block in a ComputationGraph.
OdeVertex extends GraphVertex and can be added to a GraphBuilder just like any other vertex. It has an API similar to GraphBuilder's for adding layers and vertices.
Example:
```
final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("input")
        .addLayer("normalLayer0", new BatchNormalization.Builder().build(), "input")

        // Add an ODE block named "odeBlock" after "normalLayer0"
        .addVertex("odeBlock", new OdeVertex.Builder(
                new NeuralNetConfiguration.Builder(),
                "odeLayer0", new BatchNormalization.Builder().build())

                // Build the OdeVertex. The resulting "inner graph" will be treated as an ODE
                .build(), "normalLayer0")

        // Layers/vertices can be added to the graph after the ODE block
        .addLayer("normalLayer1", new BatchNormalization.Builder().build(), "odeBlock")
        .setOutputs("output")
        .addLayer("output", new CnnLossLayer(), "normalLayer1")
        .build());
```
An inherent constraint of the method itself is that the output of the last layer in the OdeVertex must have exactly the same shape as the input to the first layer in the OdeVertex.
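
As a hypothetical sketch, a convolution used inside the OdeVertex can satisfy this by keeping the number of channels and the spatial dimensions unchanged:

```
// Shape-preserving layer for use inside an OdeVertex: nIn == nOut, and
// convolution mode "Same" keeps height and width unchanged
new Convolution2D.Builder(3, 3)
        .nIn(32)
        .nOut(32)
        .convolutionMode(ConvolutionMode.Same)
        .build();
```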
Note that OdeVertex.Builder requires a NeuralNetConfiguration.Builder as constructor input. This is because DL4J does not set graph-wise default values (e.g. updaters and weight initialization) for vertices, so the only way to apply them to the layers of the OdeVertex is to pass in the global configuration. Making it a required constructor argument will hopefully make this harder to forget. It is of course possible to use a separate set of default values for the layers of the OdeVertex by simply passing in another NeuralNetConfiguration.Builder.
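
As a sketch (updater and weight initialization values are just placeholders), giving the layers of the OdeVertex their own defaults could look like this:

```
// Separate defaults for the layers inside the OdeVertex, independent of
// the global configuration used for the rest of the graph
final NeuralNetConfiguration.Builder odeDefaults = new NeuralNetConfiguration.Builder()
        .updater(new Adam(1e-3))
        .weightInit(WeightInit.RELU);

new OdeVertex.Builder(odeDefaults, "odeLayer0", new BatchNormalization.Builder().build())
```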
The method for solving the ODE can be configured:
```
new OdeVertex.Builder(...)
        .odeConf(new FixedStep(
                new DormandPrince54Solver(),
                Nd4j.arange(0, 2))) // Integrate between t = 0 and t = 1
```
Currently, the only ODE solver implementation integrated with Nd4j is [DormandPrince54Solver](./src/main/java/ode/solve/impl/DormandPrince54Solver.java). It is, however, possible to use FirstOrderIntegrators from org.apache.commons:commons-math3 through [FirstOrderSolverAdapter](./src/main/java/ode/solve/commons/FirstOrderSolverAdapter.java) at the cost of slower training and inference.
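
As a sketch, and assuming FirstOrderSolverAdapter implements the same solver interface that FixedStep accepts, wrapping a commons-math3 integrator could look like this (step bounds and tolerances are illustrative):

```
new OdeVertex.Builder(...)
        .odeConf(new FixedStep(
                // DormandPrince853Integrator(minStep, maxStep, absTol, relTol) comes
                // from org.apache.commons.math3.ode.nonstiff
                new FirstOrderSolverAdapter(
                        new DormandPrince853Integrator(1e-10, 10d, 1e-7, 1e-7)),
                Nd4j.arange(0, 2)))
```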
Time can also be input from another vertex in the graph:
```
new OdeVertex.Builder(...)
        .odeConf(new InputStep(solverConf, 1)) // Number "1" refers to input "time" on the line below
        .build(), "someLayer", "time");
```
Note that time must be a vector, meaning it cannot be minibatched; it has to be the same for all examples in a minibatch. This is because the implementation uses the minibatching approach from section 6 of the paper, where all examples in the batch are concatenated into one state. If one time sequence per example is desired, this can be achieved by using a minibatch size of 1.
When time is used as input, gradients of the loss with respect to time will be output from the vertex, but they are set to 0 by default to save computation. To have them computed, set needTimeGradient to true:
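
A sketch of what this might look like; the exact position of the flag in the InputStep constructor is an assumption, so check the class for the actual signature:

```
final boolean needTimeGradient = true; // assumed constructor flag

new OdeVertex.Builder(...)
        .odeConf(new InputStep(solverConf, 1, needTimeGradient))
        .build(), "someLayer", "time");
```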
I have not seen these gradients being used for anything in the original implementation, and if they are used, some extra measure is most likely required to ensure that time is always strictly increasing or decreasing.
In either case, the minimum number of elements in the time vector is two. If more than two elements are given, the output of the OdeVertex will have one more dimension than the input (corresponding to each time element).
For example, if the graph in the OdeVertex is the function `f = dz/dt` and `time` is the sequence `t0, t1, ..., tN-1` with `N > 2`, then the output of the OdeVertex will be (an approximation of) the sequence `z(t0), z(t1), ..., z(tN-1)`. Note that `z(t0)` is also the input to the OdeVertex.
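
As a sketch in the style of the FixedStep example above, a five-element time vector makes the vertex output the state at each of the five time points:

```
new OdeVertex.Builder(...)
        .odeConf(new FixedStep(
                new DormandPrince54Solver(),
                Nd4j.linspace(0, 1, 5))) // outputs z(t0) ... z(t4) along an extra time dimension
```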
The exact mapping to dimensions depends on the shape of the input. Currently, the following mappings are supported:

| Input shape | Output shape |
|---|---|
|`B x D x H x W (conv 2D)`|`B x D x H x W x t (conv 3D)`|
### Prerequisites
Maven and Git. The project uses ND4J's CUDA 10 backend by default, which requires [CUDA 10](https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn).
To use the CPU backend instead, set the Maven property backend-CPU:
```
mvn install -P backend-CPU
```
## Contributing
All contributions are welcome. Head over to the issues page and either add a new issue or pick up an existing one.
Model "stem" is using the resnet option with zero resblocks after the downsampling layers. This indicates that neither the residual blocks nor the ode block seems to be contributing much to the performance in this simple experiment. Performance also varies about +-0.1% for each run of the same model.
## Spiral demo
Reimplementation of the spiral generation experiment from the [original repo](https://github.com/rtqichen/torchdiffeq/tree/master/examples).
To run the ODE net model, use the following command:
Furthermore, the original implementation does not use the adjoint method for backpropagation in this example; instead, it backpropagates through the operations of the ODE solver. Backpropagation through the ODE solver is not yet supported in this project.