
Commit 81f12ef

Merge pull request #17 from DrChainsaw/spiraldemo
Add first set of classes for spiral demo
2 parents: 43ca828 + ab89237 · commit 81f12ef

137 files changed: +7836 −387 lines


README.md

+118 −8
@@ -1,15 +1,20 @@
 # neuralODE4j
 
-Implementation of neural ordinary differential equations built for deeplearning4j.
+Travis [![Build Status](https://travis-ci.org/DrChainsaw/AmpControl.svg?branch=master)](https://travis-ci.org/DrChainsaw/NeuralODE4j)
+AppVeyor [![Build status](https://ci.appveyor.com/api/projects/status/wjdi11f4cmx32ir8?svg=true)](https://ci.appveyor.com/project/DrChainsaw/neuralode4j)
+
+[![codebeat badge](https://codebeat.co/badges/d9e719b4-5465-4f08-9c14-f924691cdd86)](https://codebeat.co/projects/github-com-drchainsaw-neuralode4j-master)
+[![Codacy Badge](https://api.codacy.com/project/badge/Grade/d491774f94944895b6aa3e22b7aae8b3)](https://www.codacy.com/app/DrChainsaw/neuralODE4j?utm_source=github.com&utm_medium=referral&utm_content=DrChainsaw/neuralODE4j&utm_campaign=Badge_Grade)
+[![Maintainability](https://api.codeclimate.com/v1/badges/c0d216da01a0c8b8d615/maintainability)](https://codeclimate.com/github/DrChainsaw/neuralODE4j/maintainability)
+[![Test Coverage](https://api.codeclimate.com/v1/badges/c0d216da01a0c8b8d615/test_coverage)](https://codeclimate.com/github/DrChainsaw/neuralODE4j/test_coverage)
+
+Implementation of neural Ordinary Differential Equations (ODE) built for [deeplearning4j](https://deeplearning4j.org/).
 
 [[Arxiv](https://arxiv.org/abs/1806.07366)]
 
 [[Pytorch repo by paper authors](https://github.com/rtqichen/torchdiffeq)]
 
-NOTE: This is very much a work in progress and given that I haven't touched a differential equation since school chances
-are that there are conceptual misunderstandings.
-
-The performance of the MNIST example is in line with the results presented in the paper, but given the simplicity of that dataset this is no guarantee of correct implementation.
+[[Very good blog post](https://julialang.org/blog/2019/01/fluxdiffeq)]
 
 ## Getting Started
 

@@ -21,20 +26,125 @@ cd neuralODE4j
 mvn install
 ```
 
-Currently only the MNIST toy experiment from the paper is implemented [[link]](./src/main/java/examples)
+I will try to create a maven artifact whenever I find the time for it. Please file an issue for this if you are interested.
+
+Implementations of the MNIST and spiral generation toy experiments from the paper can be found under examples [[link]](./src/main/java/examples)
+
+## Usage
+
+The class [OdeVertex](./src/main/java/ode/vertex/conf/OdeVertex.java) is used to add an arbitrary graph of Layers or GraphVertices as an ODE block in a ComputationGraph.
+
+OdeVertex extends GraphVertex and can be added to a GraphBuilder just like any other vertex. It has an API similar to GraphBuilder's for adding
+layers and vertices.
+
+Example:
+```
+final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
+        .graphBuilder()
+        .addInputs("input")
+        .setInputTypes(InputType.convolutional(9, 9, 3))
+        .addLayer("normalLayer0",
+                new Convolution2D.Builder(3, 3)
+                        .nOut(32)
+                        .convolutionMode(ConvolutionMode.Same).build(), "input")
+
+        // Add an ODE block called "odeBlock" to the graph.
+        .addVertex("odeBlock",
+                new OdeVertex.Builder(new NeuralNetConfiguration.Builder(), "odeLayer0", new BatchNormalization.Builder().build())
+
+                        // OdeVertex has an API similar to GraphBuilder's for adding new layers/vertices to the ODE block
+                        .addLayer("odeLayer1", new Convolution2D.Builder(3, 3)
+                                .nOut(32)
+                                .convolutionMode(ConvolutionMode.Same).build(), "odeLayer0")
+
+                        // Add more layers and vertices as desired
+
+                        // Build the OdeVertex. The resulting "inner graph" will be treated as an ODE
+                        .build(), "normalLayer0")
+
+        // Layers/vertices can be added to the graph after the ODE block
+        .addLayer("normalLayer1", new BatchNormalization.Builder().build(), "odeBlock")
+        .setOutputs("output")
+        .addLayer("output", new CnnLossLayer(), "normalLayer1")
+        .build());
+```
+
+An inherent constraint of the method itself is that the output of the last layer in the OdeVertex must have exactly the same
+shape as the input to the first layer in the OdeVertex.
+
+Note that OdeVertex.Builder requires a NeuralNetConfiguration.Builder as constructor input. This is because DL4J does not set graph-wise
+default values for things like updaters and weight initialization for vertices, so the only way to apply them to the
+Layers of the OdeVertex is to pass in the global configuration. Making it a required constructor argument will
+hopefully make this harder to forget. It is of course possible to have a separate set of default values for the layers
+of the OdeVertex by simply giving it another NeuralNetConfiguration.Builder.
+
+The method for solving the ODE can be configured:
+
+```
+new OdeVertex.Builder(...)
+        .odeConf(new FixedStep(
+                new DormandPrince54Solver(),
+                Nd4j.arange(0, 2))) // Integrate between t = 0 and t = 1
+```
+
+Currently, the only ODE solver implementation integrated with Nd4j is [DormandPrince54Solver](./src/main/java/ode/solve/impl/DormandPrince54Solver.java).
+It is however possible to use FirstOrderIntegrators from apache.commons:commons-math3 through [FirstOrderSolverAdapter](./src/main/java/ode/solve/commons/FirstOrderSolverAdapter.java)
+at the cost of slower training and inference speed.
+
+Time can also be input from another vertex in the graph:
+```
+new OdeVertex.Builder(...)
+        .odeConf(new InputStep(solverConf, 1)) // Number "1" refers to input "time" on the line below
+        .build(), "someLayer", "time");
+```
+
+Note that time must be a vector, meaning it cannot be minibatched; it has to be the same for all examples in a minibatch. This is because the implementation uses the minibatching approach from
+section 6 in the paper, where all examples in the batch are concatenated into one state. If one time sequence per example is desired, this
+can be achieved by using a minibatch size of 1.
+
+Gradients for the loss with respect to time will be output from the vertex when using time as input, but they are set to 0 by default to save computation. To have them computed, set needTimeGradient to true:
+
+```
+final boolean needTimeGradient = true;
+new OdeVertex.Builder(...)
+        .odeConf(new InputStep(solverConf, 1, true, needTimeGradient))
+        .build(), "someLayer", "time");
+```
+
+I have not seen these gradients being used for anything in the original implementation, and if they are used, some extra measure is most likely required to ensure that time is always strictly increasing or decreasing.
+
+In either case, the minimum number of elements in the time vector is two. If more than two elements are given, the output of the OdeVertex
+will have one more dimension compared to the input (corresponding to each time element).
+
+For example, if the graph in the OdeVertex is the function `f = dz/dt` and `time` is the sequence `t0, t1, ..., tN-1`
+with `N > 2`, then the output of the OdeVertex will be (an approximation of) the sequence `z(t0), z(t1), ..., z(tN-1)`.
+Note that `z(t0)` is also the input to the OdeVertex.
+
+The exact mapping to dimensions depends on the shape of the input. Currently the following mappings are supported:
+
+| Input shape                | Output shape                   |
+|----------------------------|--------------------------------|
+| `B x H (dense/FF)`         | `B x H x t (RNN)`              |
+| `B x H x T (RNN)`          | `Not supported`                |
+| `B x D x H x W (conv 2D)`  | `B x D x H x W x t (conv 3D)`  |
+
 
 ### Prerequisites
 
 Maven and Git. The project uses ND4J's CUDA 10 backend by default, which requires [CUDA 10](https://deeplearning4j.org/docs/latest/deeplearning4j-config-cudnn).
-To use CPU backend instead, set the maven property backend-CPU (e.g. through the -P flag when running from command line).
+To use the CPU backend instead, set the maven property backend-CPU:
+
+```
+mvn install -P backend-CPU
+```
 
 ## Contributing
 
 All contributions are welcome. Head over to the issues page and either add a new issue or pick up an existing one.
 
 ## Versioning
 
-TBD
+TBD.
 
 ## Authors
 
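Editor's aside: to tie the Usage section above together, below is a minimal sketch of the time-as-input configuration, in which a dense state `z0` is integrated over ten time points and, per the mapping table, comes out RNN-shaped. This is not part of the commit: the graph layout and the names `z0`, `time`, `latentOde` and `out` are hypothetical, while OdeVertex, InputStep and DormandPrince54Solver are the classes documented above (imports as in the Usage example, plus RnnLossLayer, LossFunctions, INDArray and Nd4j).

```
// Hypothetical sketch: dense input (B x H) plus a time vector with N > 2
// elements yields an RNN-shaped output (B x H x t), as per the table above.
final int nrofTimeSteps = 10;
final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("z0", "time")
        .setInputTypes(InputType.feedForward(4), InputType.feedForward(nrofTimeSteps))
        .addVertex("latentOde",
                new OdeVertex.Builder(new NeuralNetConfiguration.Builder(),
                        "ode0", new DenseLayer.Builder().nOut(4).build())
                        // Shape constraint: the inner graph maps 4 units to 4 units
                        .odeConf(new InputStep(new DormandPrince54Solver(), 1)) // input 1 is "time"
                        .build(), "z0", "time")
        // Output of the OdeVertex is B x 4 x 10 (RNN shape), so an RNN loss layer fits
        .addLayer("out", new RnnLossLayer.Builder(LossFunctions.LossFunction.MSE).build(), "latentOde")
        .setOutputs("out")
        .build());
graph.init();

final INDArray time = Nd4j.linspace(0, 1, nrofTimeSteps); // plain vector, shared by the whole minibatch
final INDArray z0 = Nd4j.rand(16, 4);                     // minibatch of 16 initial states
final INDArray zt = graph.outputSingle(z0, time);         // approximates z(t0) ... z(t9)
```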
pom.xml

+1 −1

@@ -6,7 +6,7 @@
 
     <groupId>io.github.drchainsaw</groupId>
     <artifactId>neuralODE4j</artifactId>
-    <version>0.0.1-SNAPSHOT</version>
+    <version>0.8.0</version>
 
     <properties>
 
src/main/java/examples/README.md

+32 −1

@@ -1,6 +1,6 @@
 # Examples
 
-# MNIST
+## MNIST
 
 Reimplementation of the MNIST experiment from the [original repo](https://github.com/rtqichen/torchdiffeq/tree/master/examples).
 
@@ -16,6 +16,11 @@ mvn exec:java -Dexec.mainClass="examples.mnist.Main" -Dexec.args="odenet"
 
 Running from the IDE is also possible, in which case resnet/odenet must be set as program arguments.
 
+Use -help for a full list of command line arguments:
+
+```
+mvn exec:java -Dexec.mainClass="examples.mnist.Main" -Dexec.args="-help"
+```
 
 Performance (approx):
 
@@ -26,3 +31,29 @@ Performance (approx):
 | stem | 0.5% |
 
 Model "stem" is using the resnet option with zero resblocks after the downsampling layers. This indicates that neither the residual blocks nor the ode block seems to be contributing much to the performance in this simple experiment. Performance also varies by about ±0.1% between runs of the same model.
+
+## Spiral demo
+
+Reimplementation of the spiral generation experiment from the [original repo](https://github.com/rtqichen/torchdiffeq/tree/master/examples).
+
+To run the ODE net model, use the following command:
+
+```
+mvn exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="odenet"
+```
+
+Running from the IDE is also possible, in which case odenet must be set as a program argument.
+
+Use -help for a full list of command line arguments:
+
+```
+mvn exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="-help"
+```
+
+Note that this example tends to run faster on CPU than on GPU, probably due to the relatively low number of parameters. Example:
+
+```
+mvn -P backend-CPU exec:java -Dexec.mainClass="examples.spiral.Main" -Dexec.args="odenet"
+```
+
+Furthermore, the original implementation does not use the adjoint method for backpropagation in this example and instead backpropagates through the operations of the ODE solver. Backpropagation through the ODE solver is not yet supported in this project.

src/main/java/examples/mnist/Main.java

+9 −1

@@ -37,6 +37,9 @@ class Main {
 
     private static final Logger log = LoggerFactory.getLogger(Main.class);
 
+    @Parameter(names = {"-help", "-h"}, description = "Prints help message")
+    private boolean help = false;
+
     @Parameter(names = "-trainBatchSize", description = "Batch size to use for training")
     private int trainBatchSize = 128;
 
@@ -85,6 +88,11 @@ private static Main parseArgs(String[] args) {
         JCommander jCommander = parbuilder.build();
         jCommander.parse(args);
 
+        if(main.help) {
+            jCommander.usage();
+            System.exit(0);
+        }
+
         final ModelFactory factory = modelCommands.get(jCommander.getParsedCommand());
 
         main.init(factory.create(), factory.name());
@@ -103,7 +111,7 @@ private void init(ComputationGraph model, String modelName) {
     }
 
     private void addListeners() {
-        final File savedir = new File("savedmodels" + File.separator + modelName);
+        final File savedir = new File("savedmodels" + File.separator + "MNIST" + File.separator + modelName);
         log.info("Models will be saved in: " + savedir.getAbsolutePath());
         savedir.mkdirs();
         model.addListeners(

src/main/java/examples/mnist/OdeNetModel.java

+5 −1

@@ -1,13 +1,16 @@
 package examples.mnist;
 
 import com.beust.jcommander.Parameter;
+import com.beust.jcommander.Parameters;
 import com.beust.jcommander.ParametersDelegate;
 import ode.solve.api.FirstOrderSolverConf;
 import ode.solve.conf.DormandPrince54Solver;
 import ode.solve.conf.SolverConfig;
 import ode.vertex.conf.OdeVertex;
+import ode.vertex.conf.helper.FixedStep;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
 import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.factory.Nd4j;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import util.listen.step.Mask;
@@ -20,6 +23,7 @@
  *
  * @author Christian Skarby
  */
+@Parameters(commandDescription = "Configuration for image classification using an ODE block")
 public class OdeNetModel implements ModelFactory {
 
     private static final Logger log = LoggerFactory.getLogger(OdeNetModel.class);
@@ -77,7 +81,7 @@ private String addOdeBlock(String prev, FirstOrderSolverConf solver) {
                         conv3x3Same(nrofKernels), "normSecond")
                 .addLayer("normThird",
                         norm(nrofKernels), "convSecond")
-                .odeSolver(solver)
+                .odeConf(new FixedStep(solver, Nd4j.arange(2)))
                 .build(), prev);
        return "odeBlock";
    }
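A note on the change from `.odeSolver(solver)` to `.odeConf(new FixedStep(solver, Nd4j.arange(2)))` above: `Nd4j.arange(2)` produces the vector `[0.0, 1.0]`, so the block is integrated once from t = 0 to t = 1, presumably matching what the removed `.odeSolver(solver)` call did implicitly. A small editor's sketch, assuming only the classes visible in this diff (the demo class itself is hypothetical):

```
import ode.solve.conf.DormandPrince54Solver;
import ode.vertex.conf.helper.FixedStep;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class FixedStepDemo {
    public static void main(String[] args) {
        // Row vector of integration times: [0.0, 1.0]
        final INDArray time = Nd4j.arange(2);
        // FixedStep fixes the time range at configuration time, in contrast
        // to InputStep, where time arrives as a separate graph input
        final FixedStep odeConf = new FixedStep(new DormandPrince54Solver(), time);
        System.out.println(time); // prints [0, 1]
    }
}
```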

src/main/java/examples/mnist/ResNetReferenceModel.java

+2

@@ -1,6 +1,7 @@
 package examples.mnist;
 
 import com.beust.jcommander.Parameter;
+import com.beust.jcommander.Parameters;
 import com.beust.jcommander.ParametersDelegate;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
 import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
@@ -15,6 +16,7 @@
  *
  * @author Christian Skarby
 */
+@Parameters(commandDescription = "Configuration for image classification using a number of residual blocks")
 public class ResNetReferenceModel implements ModelFactory {
 
     private static final Logger log = LoggerFactory.getLogger(ResNetReferenceModel.class);
src/main/java/examples/spiral/AddKLDLabel.java

+36

@@ -0,0 +1,36 @@
+package examples.spiral;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+
+/**
+ * Adds a label for KLD loss
+ *
+ * @author Christian Skarby
+ */
+public class AddKLDLabel implements MultiDataSetPreProcessor {
+
+    private final double mean;
+    private final double logvar;
+    private final long nrofLatentDims;
+
+    public AddKLDLabel(double mean, double var, long nrofLatentDims) {
+        this.mean = mean;
+        this.logvar = Math.log(var);
+        this.nrofLatentDims = nrofLatentDims;
+    }
+
+    @Override
+    public void preProcess(MultiDataSet multiDataSet) {
+        final INDArray label0 = multiDataSet.getLabels(0);
+        final long batchSize = label0.size(0);
+        final INDArray kldLabel = Nd4j.zeros(batchSize, 2*nrofLatentDims);
+        kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, nrofLatentDims)}, mean);
+        kldLabel.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(nrofLatentDims, 2*nrofLatentDims)}, logvar);
+        multiDataSet.setLabels(new INDArray[]{label0, kldLabel});
+    }
+}
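A sketch of what AddKLDLabel does to a MultiDataSet; the toy data and demo class below are hypothetical, and only AddKLDLabel itself comes from this commit:

```
package examples.spiral;

import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;

class AddKLDLabelDemo {
    public static void main(String[] args) {
        final long nrofLatentDims = 3;
        // Toy MultiDataSet with one feature array and one existing label array
        final MultiDataSet mds = new MultiDataSet(
                Nd4j.rand(4, 10),  // features for a batch of 4
                Nd4j.rand(4, 10)); // label 0, left untouched
        // Prior with mean 0 and variance 1 (so log-variance 0)
        new AddKLDLabel(0.0, 1.0, nrofLatentDims).preProcess(mds);
        // A second label is appended with shape [4, 2 * nrofLatentDims] = [4, 6]:
        // first half filled with the mean, second half with the log-variance
        System.out.println(Arrays.toString(mds.getLabels(1).shape()));
    }
}
```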
src/main/java/examples/spiral/Block.java

+19

@@ -0,0 +1,19 @@
+package examples.spiral;
+
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+
+/**
+ * A simple block of layers
+ *
+ * @author Christian Skarby
+ */
+interface Block {
+
+    /**
+     * Add layers to given builder
+     * @param builder Builder to add layers to
+     * @param prev previous layers
+     * @return name of last layer added
+     */
+    String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev);
+}
src/main/java/examples/spiral/DenseDecoderBlock.java

+43

@@ -0,0 +1,43 @@
+package examples.spiral;
+
+import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
+import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.activations.impl.ActivationReLU;
+
+/**
+ * Simple decoder using {@link DenseLayer}s. Also uses a {@link FeedForwardToRnnPreProcessor} as it assumes 3D input.
+ *
+ * @author Christian Skarby
+ */
+public class DenseDecoderBlock implements Block {
+
+    private final long nrofHidden;
+    private final long nrofOutputs;
+
+    public DenseDecoderBlock(long nrofHidden, long nrofOutputs) {
+        this.nrofHidden = nrofHidden;
+        this.nrofOutputs = nrofOutputs;
+    }
+
+    @Override
+    public String add(ComputationGraphConfiguration.GraphBuilder builder, String... prev) {
+        builder
+                .addLayer("dec0", new DenseLayer.Builder()
+                        .nOut(nrofHidden)
+                        .activation(new ActivationReLU())
+                        .build(), prev)
+                .addLayer("dec1", new DenseLayer.Builder()
+                        .nOut(nrofOutputs)
+                        .activation(new ActivationIdentity())
+                        .build(), "dec0")
+                .addVertex("decodedOutput",
+                        new PreprocessorVertex(
+                                new FeedForwardToRnnPreProcessor()),
+                        "dec1");
+
+        return "decodedOutput";
+    }
+}
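To illustrate how Block implementations such as DenseDecoderBlock above are meant to compose into a graph, a usage sketch follows; the builder setup, the loss layer and the demo class are hypothetical:

```
package examples.spiral;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class DecoderDemo {
    public static void main(String[] args) {
        final ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("latent")
                .setInputTypes(InputType.recurrent(4)); // 3D input, as DenseDecoderBlock assumes

        // A Block adds its layers and returns the name of its last layer,
        // so blocks can be chained without hard-coding internal layer names
        final Block decoder = new DenseDecoderBlock(20, 2);
        final String last = decoder.add(builder, "latent"); // == "decodedOutput"

        builder.addLayer("output", new RnnLossLayer.Builder(LossFunctions.LossFunction.MSE).build(), last)
                .setOutputs("output");
    }
}
```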
