-
-
Notifications
You must be signed in to change notification settings - Fork 213
Artificial Neural Networks
Anders Peterson edited this page Mar 19, 2019
·
4 revisions
An updated version of this page can be found at http://www.ojalgo.org/2018/09/introducing-artificial-neural-networks-with-ojalgo/
To demonstrate how to use ojAlgo's neural network feature, this example constructs a network that classifies the handwritten digits of the MNIST database.
To run the example you have to download the MNIST data files yourself (they are available from http://yann.lecun.com/exdb/mnist/) and update the file paths in the example code.
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;
import org.ojalgo.ann.ANN;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.ann.NetworkBuilder;
import org.ojalgo.array.ArrayAnyD;
import org.ojalgo.constant.PrimitiveMath;
import org.ojalgo.function.PrimitiveFunction;
import org.ojalgo.matrix.store.PrimitiveDenseStore;
import org.ojalgo.netio.BasicLogger;
import org.ojalgo.netio.IDX;
import org.ojalgo.structure.MatrixView;
/**
* An example of how to build, train and use artificial neural networks with ojAlgo.
*/
/**
 * An example of how to build, train and use artificial neural networks with ojAlgo.
 * <p>
 * A network with one hidden layer is trained on the MNIST training set and then evaluated on the
 * MNIST test set. The paths below point to the IDX files of that data set and must be adjusted to
 * wherever you downloaded them.
 */
public class TrainingANN {

    static final String OUTPUT_TEST_IMAGES = "/Users/apete/Developer/data/images/test/";
    static final String OUTPUT_TRAINING_IMAGES = "/Users/apete/Developer/data/images/training/";
    static final String TEST_IMAGES = "/Users/apete/Developer/data/t10k-images-idx3-ubyte";
    static final String TEST_LABELS = "/Users/apete/Developer/data/t10k-labels-idx1-ubyte";
    static final String TRAINING_IMAGES = "/Users/apete/Developer/data/train-images-idx3-ubyte";
    static final String TRAINING_LABELS = "/Users/apete/Developer/data/train-labels-idx1-ubyte";

    /** Side length, in pixels, of each (square) MNIST image. */
    private static final int IMAGE_SIZE = 28;
    /** Number of nodes in the single hidden layer. */
    private static final int HIDDEN_NODES = 200;
    /** Number of output nodes - one per digit 0-9. */
    private static final int NUMBER_OF_DIGITS = 10;
    /** Number of passes over the full training set. */
    private static final int EPOCHS = 10;

    /**
     * Trains the network, evaluates it on the test set and logs the resulting error rate.
     *
     * @param args ignored
     * @throws IOException if reading the data or writing the optional PNG images fails
     */
    public static void main(final String[] args) throws IOException {

        boolean generateImages = false;
        int numberToPrint = 10;

        NetworkBuilder builder = ArtificialNeuralNetwork.builder(IMAGE_SIZE * IMAGE_SIZE, HIDDEN_NODES, NUMBER_OF_DIGITS);
        builder.activators(ANN.Activator.SIGMOID, ANN.Activator.SOFTMAX).error(ANN.Error.CROSS_ENTROPY).randomise().rate(0.05);

        ArrayAnyD<Double> trainingLabels = IDX.parse(TRAINING_LABELS);
        ArrayAnyD<Double> trainingImages = TrainingANN.loadImages(TRAINING_IMAGES);

        // Reused buffers: one image as network input, one-hot encoded digit as expected output.
        PrimitiveDenseStore input = PrimitiveDenseStore.FACTORY.makeZero(IMAGE_SIZE, IMAGE_SIZE);
        PrimitiveDenseStore output = PrimitiveDenseStore.FACTORY.makeZero(NUMBER_OF_DIGITS, 1);

        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            for (MatrixView<Double> imageData : trainingImages.matrices()) {

                long imageIndex = imageData.index();
                input.fillMatching(imageData);

                long digitIndex = trainingLabels.longValue(imageIndex);
                output.fillAll(PrimitiveMath.ZERO);
                output.set(digitIndex, PrimitiveMath.ONE);

                builder.train(input, output);

                // The images are identical in every epoch, so only write them once.
                if (generateImages && epoch == 0) {
                    TrainingANN.generateImage(imageData, digitIndex, OUTPUT_TRAINING_IMAGES);
                }
            }
        }

        ArtificialNeuralNetwork network = builder.get();

        ArrayAnyD<Double> testLabels = IDX.parse(TEST_LABELS);
        ArrayAnyD<Double> testImages = TrainingANN.loadImages(TEST_IMAGES);

        int right = 0;
        int wrong = 0;

        for (MatrixView<Double> imageData : testImages.matrices()) {

            long imageIndex = imageData.index();
            input.fillMatching(imageData);

            network.invoke(input).supplyTo(output);

            long expected = testLabels.longValue(imageIndex);
            // The predicted digit is the output node with the largest value.
            long actual = output.indexOfLargest();

            if (actual == expected) {
                right++;
            } else {
                wrong++;
            }

            if (imageIndex < numberToPrint) {
                BasicLogger.debug("");
                BasicLogger.debug("Image {}: {} <=> {}", imageIndex, expected, actual);
                IDX.print(input, BasicLogger.DEBUG, true, 1.0);
            }

            if (generateImages) {
                TrainingANN.generateImage(imageData, expected, OUTPUT_TEST_IMAGES);
            }
        }

        BasicLogger.debug("");
        BasicLogger.debug("=========================================================");
        BasicLogger.debug("Error rate: {}", (double) wrong / (double) (right + wrong));
    }

    /**
     * Parses an IDX image file and rescales every pixel from [0,255] to [0,1].
     *
     * @param path location of the IDX file
     * @return the rescaled image data
     * @throws IOException declared for callers in case parsing involves I/O failures
     */
    private static ArrayAnyD<Double> loadImages(String path) throws IOException {
        ArrayAnyD<Double> images = IDX.parse(path);
        images.modifyAll(PrimitiveFunction.DIVIDE.second(255));
        return images;
    }

    /**
     * Writes one image as a grey-scale PNG file named {@code <100000 + image index>_<label>.png}.
     * The target directory is created if it does not exist yet.
     *
     * @param imageData pixel values, already rescaled to [0,1], where 1 means "ink"
     * @param imageLabel the digit the image represents; becomes part of the file name
     * @param directoryPath directory to write to; the file name is appended directly, so it is
     *        expected to end with a path separator
     * @throws IOException if the directory cannot be created or the PNG cannot be written
     */
    private static void generateImage(MatrixView<Double> imageData, long imageLabel, String directoryPath) throws IOException {

        int numberOfRows = (int) imageData.countRows();
        int numberOfColumns = (int) imageData.countColumns();

        // NOTE(review): the matrix row index is used as the x (horizontal) coordinate. This
        // presumably matches how IDX.parse lays out the pixels (main prints the same data via
        // IDX.print(..., true, ...), which looks like a transpose flag) - TODO confirm.
        BufferedImage image = new BufferedImage(numberOfRows, numberOfColumns, BufferedImage.TYPE_INT_ARGB);

        for (int i = 0; i < numberOfRows; i++) {
            for (int j = 0; j < numberOfColumns; j++) {
                // The colours are stored inverted in the IDX-files (255 means "ink"
                // and 0 means "no ink". In computer graphics 255 usually means "white"
                // and 0 "black".) In addition the image data has already been rescaled
                // to be in the range [0,1]. That's why...
                int gray = (int) (255.0 * (1.0 - imageData.doubleValue(i, j)));
                int rgb = 0xFF000000 | (gray << 16) | (gray << 8) | gray;
                image.setRGB(i, j, rgb);
            }
        }

        String fullPathAndName = directoryPath + (100000 + imageData.index()) + "_" + imageLabel + ".png";
        File outputFile = new File(fullPathAndName);

        File directory = outputFile.getParentFile();
        if (directory != null && !directory.isDirectory() && !directory.mkdirs()) {
            throw new IOException("Could not create directory " + directory);
        }

        // ImageIO.write reports "no suitable writer" via its return value, not an exception.
        if (!ImageIO.write(image, "png", outputFile)) {
            throw new IOException("No PNG writer available for " + outputFile);
        }
    }

}
## Console output
Image 0: 7 <=> 7
X++
XXXXXXXXXXXXXXX
+ ++XXXXXXXXXX+
XX+
XX
XX
+XX
XX
+XX
XX
+XX
XX
XX+
XXX
XX
XX+
+XX
XXX
+XXX
+XX
Image 1: 2 <=> 2
++XXX++
+XXXXXXX
+XXXX+XXX+
XXX XX+
XX +XX
XXX
+XXX
XXX
+XX+
XXX+
XXX
XXX+
XXX
XXX+
+XXX
XXX
XXX ++++
XXXXXXXX+++XXXXXXX+
XXXXXXXXXXXXXXX+++
+++++XXX+++
Image 2: 1 <=> 1
X+
+X
+X
X+
X
XX
XX
+XX
+X
XX
+X+
XX
XX
+X+
+X+
XX
XX
+XX
XX+
XX
Image 3: 0 <=> 0
+XX
XXX+
XXXX+
+XXXXX++
XXXXXXXXX
XXXXXXXXXXX
XXXXX+ +XX+
XXXXX+ XXX+
XXXX +XXX
XXX XXX
XX XXX+
+XX +XXX
XXX XXXX
XXX +XXXX
XXX XXXXXX
XXX XXXXXXXX
+XXXXXXXXXXX+
XXXXXXXXXX
+XXXXXX+
XXX
Image 4: 4 <=> 4
X
+X ++
XX +X
X+ +X
XX XX
+X XX
XX +XX
+X+ XX
XX +XX
X+ +X+
XX XX+
+XX++++++XXX+
XXXXXXXXXXX
+XX
+XX
+XX
+XX
+XX
XX+
X
Image 5: 1 <=> 1
X+
XXX
XX+
XXX
XXX
+XXX
XXX
XXX
+XX+
+XX
XXX
XX+
+XX
XXX
+XX+
+XX
XXX
XXX
+X+
++
Image 6: 4 <=> 4
X+
XX+ XX
XX+ +X+
XX XX
+X+ XX+
+XX +X+
+XX +XX
+XXXX +XXX+
XXXXXXXXXXXX
+XX+++ XX+
XX
+X
+XX
+X+
XX
+XX
XX+
XX++X+
XXXX+
X+
Image 7: 9 <=> 9
+X
+XXX
XXXX+
+XXXXXX+
XXXXXXXX
XX +XXXX
+X XXXX
XX+ XXX+
+XX ++XXXX
XXXXXX XX+
XXX+ +XX
++ XX+
+XX
+X+
XX
+XX
XX
X+
XX
X
Image 8: 5 <=> 5
++
+XXXXXXXX
+XXXXXXXXXXX
XXXXXXXXXXXX
+ XXX++
XX
XX+
+XX
+XX
+XX
XXX+
XXXXX++++ +
+XXXXXXXXXXXX+
XXXXXXXXXXXX+
XXXXXXXXX
+XX+ +XXX
XXXXXXX+
+XXXXXX
XXXX+
++
Image 9: 9 <=> 9
+XX+
+XXXXXXXX+
+XXXX++XXXXX+
+XXXX+ X XXXX
+XXX+ X XXX
XXX +XXXX+
XXX++XX+XXXXXXX
XXXXXXXXXXXXXX
+XX++XXXXX+
+XXX+
+XXX+
+XXXX
XXXX
+XXX
XXX+
XXXX
XXX
XXX+
XXX
+X+
=========================================================
Error rate: 0.0217