Skip to content

Artificial Neural Networks

Anders Peterson edited this page Mar 19, 2019 · 4 revisions

An updated version of this page can be found at http://www.ojalgo.org/2018/09/introducing-artificial-neural-networks-with-ojalgo/


To demonstrate how to use ojAlgo's neural network feature, this example constructs a network that interprets the MNIST database of handwritten digits.

To run the example, you have to download the MNIST data files yourself and update the file paths (the static constants at the top of the class) in the example code.

Example code

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

import javax.imageio.ImageIO;

import org.ojalgo.ann.ANN;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.ann.NetworkBuilder;
import org.ojalgo.array.ArrayAnyD;
import org.ojalgo.constant.PrimitiveMath;
import org.ojalgo.function.PrimitiveFunction;
import org.ojalgo.matrix.store.PrimitiveDenseStore;
import org.ojalgo.netio.BasicLogger;
import org.ojalgo.netio.IDX;
import org.ojalgo.structure.MatrixView;

/**
 * An example of how to build, train and use artificial neural networks with ojAlgo. 
 */
public class TrainingANN {

    static final String OUTPUT_TEST_IMAGES = "/Users/apete/Developer/data/images/test/";
    static final String OUTPUT_TRAINING_IMAGES = "/Users/apete/Developer/data/images/training/";
    static final String TEST_IMAGES = "/Users/apete/Developer/data/t10k-images-idx3-ubyte";
    static final String TEST_LABELS = "/Users/apete/Developer/data/t10k-labels-idx1-ubyte";
    static final String TRAINING_IMAGES = "/Users/apete/Developer/data/train-images-idx3-ubyte";
    static final String TRAINING_LABELS = "/Users/apete/Developer/data/train-labels-idx1-ubyte";

    public static void main(final String[] args) throws IOException {

        boolean generateImages = false;
        int numberToPrint = 10;

        NetworkBuilder builder = ArtificialNeuralNetwork.builder(28 * 28, 200, 10);

        builder.activators(ANN.Activator.SIGMOID, ANN.Activator.SOFTMAX).error(ANN.Error.CROSS_ENTROPY).randomise().rate(0.05);

        ArrayAnyD<Double> trainingLabels = IDX.parse(TRAINING_LABELS);
        ArrayAnyD<Double> trainingImages = IDX.parse(TRAINING_IMAGES);
        trainingImages.modifyAll(PrimitiveFunction.DIVIDE.second(255));

        PrimitiveDenseStore input = PrimitiveDenseStore.FACTORY.makeZero(28, 28);
        PrimitiveDenseStore output = PrimitiveDenseStore.FACTORY.makeZero(10, 1);

        for (int l = 0; l < 10; l++) { 
            for (MatrixView<Double> imageData : trainingImages.matrices()) {
                long imageIndex = imageData.index();

                input.fillMatching(imageData);

                long digitIndex = trainingLabels.longValue(imageIndex);
                output.fillAll(PrimitiveMath.ZERO);
                output.set(digitIndex, PrimitiveMath.ONE);

                builder.train(input, output);

                if (generateImages) {
                    TrainingANN.generateImage(imageData, digitIndex, OUTPUT_TRAINING_IMAGES);
                }
            }
        }

        ArtificialNeuralNetwork network = builder.get();

        ArrayAnyD<Double> testLabels = IDX.parse(TEST_LABELS);
        ArrayAnyD<Double> testImages = IDX.parse(TEST_IMAGES);
        testImages.modifyAll(PrimitiveFunction.DIVIDE.second(255));

        int right = 0;
        int wrong = 0;

        for (MatrixView<Double> imageData : testImages.matrices()) {
            long imageIndex = imageData.index();

            input.fillMatching(imageData);
            network.invoke(input).supplyTo(output);

            long expected = testLabels.longValue(imageIndex);
            long actual = output.indexOfLargest();

            if (actual == expected) {
                right++;
            } else {
                wrong++;
            }

            if (imageIndex < numberToPrint) {
                BasicLogger.debug("");
                BasicLogger.debug("Image {}: {} <=> {}", imageIndex, expected, actual);
                IDX.print(input, BasicLogger.DEBUG, true, 1.0);
            }

            if (generateImages) {
                TrainingANN.generateImage(imageData, expected, OUTPUT_TEST_IMAGES);
            }
        }

        BasicLogger.debug("");
        BasicLogger.debug("=========================================================");
        BasicLogger.debug("Error rate: {}", (double) wrong / (double) (right + wrong));
    }

    private static void generateImage(MatrixView<Double> imageData, long imageLabel, String directoryPath) throws IOException {

        int numberOfRows = (int) imageData.countRows();
        int numberOfColumns = (int) imageData.countColumns();

        BufferedImage image = new BufferedImage(numberOfRows, numberOfColumns, BufferedImage.TYPE_INT_ARGB);

        for (int i = 0; i < imageData.countRows(); i++) {
            for (int j = 0; j < imageData.countColumns(); j++) {
                // The colours are stored inverted in the IDX-files (255 means "ink"
                // and 0 means "no ink". In computer graphics 255 usually means "white"
                // and 0 "black".) In addition the image data has already been rescaled
                // to be in the range [0,1]. That's why...
                int gray = (int) (255.0 * (1.0 - imageData.doubleValue(i, j)));
                int rgb = 0xFF000000 | (gray << 16) | (gray << 8) | gray;
                image.setRGB(i, j, rgb);
            }
        }

        String fullPathAndName = directoryPath + (100000 + imageData.index()) + "_" + imageLabel + ".png";
        File outputfile = new File(fullPathAndName);

        ImageIO.write(image, "png", outputfile);
    }

}

## Console output

Image 0: 7 <=> 7
                            
                            
                            
                            
                            
                            
                            
       X++                  
      XXXXXXXXXXXXXXX       
       + ++XXXXXXXXXX+      
                   XX+      
                   XX       
                  XX        
                 +XX        
                 XX         
                +XX         
                XX          
               +XX          
               XX           
              XX+           
             XXX            
             XX             
            XX+             
           +XX              
           XXX              
          +XXX              
          +XX               
                            

Image 1: 2 <=> 2
                            
                            
                            
          ++XXX++           
         +XXXXXXX           
        +XXXX+XXX+          
        XXX    XX+          
        XX    +XX           
              XXX           
             +XXX           
             XXX            
            +XX+            
           XXX+             
           XXX              
          XXX+              
          XXX               
         XXX+               
        +XXX                
        XXX                 
        XXX           ++++  
        XXXXXXXX+++XXXXXXX+ 
        XXXXXXXXXXXXXXX+++  
         +++++XXX+++        
                            
                            
                            
                            
                            

Image 2: 1 <=> 1
                            
                            
                            
                            
                 X+         
                +X          
                +X          
                X+          
                X           
               XX           
               XX           
              +XX           
              +X            
              XX            
             +X+            
             XX             
             XX             
            +X+             
            +X+             
            XX              
            XX              
           +XX              
           XX+              
           XX               
                            
                            
                            
                            

Image 3: 0 <=> 0
                            
                            
                            
                            
             +XX            
             XXX+           
            XXXX+           
          +XXXXX++          
          XXXXXXXXX         
         XXXXXXXXXXX        
         XXXXX+  +XX+       
        XXXXX+    XXX+      
        XXXX      +XXX      
        XXX        XXX      
        XX         XXX+     
       +XX        +XXX      
       XXX        XXXX      
       XXX      +XXXX       
       XXX     XXXXXX       
       XXX  XXXXXXXX        
       +XXXXXXXXXXX+        
        XXXXXXXXXX          
         +XXXXXX+           
           XXX              
                            
                            
                            
                            

Image 4: 4 <=> 4
                            
                            
                            
                            
                            
           X                
          +X       ++       
          XX       +X       
          X+       +X       
         XX        XX       
        +X         XX       
        XX        +XX       
       +X+        XX        
       XX        +XX        
       X+        +X+        
       XX        XX+        
       +XX++++++XXX+        
        XXXXXXXXXXX         
                +XX         
                +XX         
                +XX         
                +XX         
                +XX         
                XX+         
                X           
                            
                            
                            

Image 5: 1 <=> 1
                            
                            
                            
                            
                            
                 X+         
                XXX         
                XX+         
               XXX          
               XXX          
              +XXX          
              XXX           
              XXX           
             +XX+           
             +XX            
             XXX            
             XX+            
            +XX             
            XXX             
           +XX+             
           +XX              
           XXX              
           XXX              
           +X+              
            ++              
                            
                            
                            

Image 6: 4 <=> 4
                            
                            
                            
                            
                            
          X+                
         XX+         XX     
        XX+         +X+     
        XX          XX      
       +X+         XX+      
      +XX         +X+       
      +XX        +XX        
      +XXXX    +XXX+        
       XXXXXXXXXXXX         
         +XX+++ XX+         
                XX          
               +X           
              +XX           
              +X+           
              XX            
             +XX            
             XX+            
             XX++X+         
             XXXX+          
              X+            
                            
                            
                            

Image 7: 9 <=> 9
                            
                            
                            
                            
                            
                            
            +X              
          +XXX              
          XXXX+             
         +XXXXXX+           
         XXXXXXXX           
         XX  +XXXX          
        +X    XXXX          
        XX+   XXX+          
        +XX ++XXXX          
         XXXXXX XX+         
          XXX+  +XX         
           ++    XX+        
                 +XX        
                  +X+       
                   XX       
                   +XX      
                    XX      
                     X+     
                     XX     
                      X     
                            
                            

Image 8: 5 <=> 5
                            
                            
                            
                            
                      ++    
                +XXXXXXXX   
             +XXXXXXXXXXX   
             XXXXXXXXXXXX   
         +   XXX++          
        XX                  
       XX+                  
      +XX                   
     +XX                    
     +XX                    
     XXX+                   
     XXXXX++++  +           
     +XXXXXXXXXXXX+         
       XXXXXXXXXXXX+        
           XXXXXXXXX        
           +XX+ +XXX        
            XXXXXXX+        
            +XXXXXX         
             XXXX+          
              ++            
                            
                            
                            
                            

Image 9: 9 <=> 9
                            
                            
                            
                            
                            
                            
                            
               +XX+         
            +XXXXXXXX+      
          +XXXX++XXXXX+     
        +XXXX+   X XXXX     
       +XXX+     X  XXX     
       XXX       +XXXX+     
       XXX++XX+XXXXXXX      
       XXXXXXXXXXXXXX       
         +XX++XXXXX+        
              +XXX+         
             +XXX+          
            +XXXX           
            XXXX            
           +XXX             
           XXX+             
          XXXX              
          XXX               
         XXX+               
         XXX                
         +X+                
                            

=========================================================
Error rate: 0.0217
Clone this wiki locally