|
| 1 | +package org.mastodon.mamut.feature.dimensionalityreduction.pca; |
| 2 | + |
| 3 | +import java.util.ArrayList; |
| 4 | +import java.util.List; |
| 5 | + |
| 6 | +import org.apache.spark.api.java.JavaRDD; |
| 7 | +import org.apache.spark.api.java.JavaSparkContext; |
| 8 | +import org.apache.spark.mllib.linalg.Matrix; |
| 9 | +import org.apache.spark.mllib.linalg.Vector; |
| 10 | +import org.apache.spark.mllib.linalg.Vectors; |
| 11 | +import org.apache.spark.mllib.linalg.distributed.RowMatrix; |
| 12 | +import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints; |
| 13 | +import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools; |
| 14 | + |
| 15 | +public class PCADemo |
| 16 | +{ |
| 17 | + // TODO: check https://stackoverflow.com/questions/10604507/pca-implementation-in-java |
| 18 | + public static void main( String[] args ) |
| 19 | + { |
| 20 | + double[][] inputData = RandomDataTools.generateSampleData(); |
| 21 | + double[][] result = setUpPCA( inputData ); |
| 22 | + PlotPoints.plot( inputData, result, resultValues -> resultValues[ 0 ] > 10 ); |
| 23 | + } |
| 24 | + |
| 25 | + static double[][] setUpPCA( double[][] inputData ) |
| 26 | + { |
| 27 | + try (JavaSparkContext jsc = new JavaSparkContext( "local", "PCA" )) |
| 28 | + { |
| 29 | + List< Vector > data = new ArrayList<>(); |
| 30 | + for ( final double[] row : inputData ) |
| 31 | + data.add( Vectors.dense( row ) ); |
| 32 | + JavaRDD< Vector > rows = jsc.parallelize( data ); |
| 33 | + // Create a RowMatrix from JavaRDD<Vector>. |
| 34 | + RowMatrix rowMatrix = new RowMatrix( rows.rdd() ); |
| 35 | + // Compute the top 2 principal components. |
| 36 | + // Principal components are stored in a local dense matrix. |
| 37 | + Matrix pc = rowMatrix.computePrincipalComponents( 2 ); |
| 38 | + // Project the rows to the linear space spanned by the top 4 principal components. |
| 39 | + RowMatrix projected = rowMatrix.multiply( pc ); |
| 40 | + return projected.rows().toJavaRDD().collect().stream().map( Vector::toArray ).toArray( double[][]::new ); |
| 41 | + } |
| 42 | + } |
| 43 | +} |
0 commit comments