Skip to content

Commit e9c2988

Browse files
author
Stefan Hahmann
committed
Add (unfinished) PCADemo
1 parent e0fb205 commit e9c2988

2 files changed

Lines changed: 51 additions & 0 deletions

File tree

pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@
129129
<version>v2.6.4</version>
130130
</dependency>
131131

132+
<!-- PCA -->
133+
<!-- Example: https://spark.apache.org/docs/latest/mllib-dimensionality-reduction#principal-component-analysis-pca -->
134+
<dependency>
135+
<groupId>org.apache.spark</groupId>
136+
<artifactId>spark-mllib_2.13</artifactId>
137+
<version>3.5.3</version>
138+
</dependency>
139+
132140
<!-- Standardization for UMAP preprocessing -->
133141
<dependency>
134142
<groupId>org.apache.commons</groupId>
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package org.mastodon.mamut.feature.dimensionalityreduction.pca;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
6+
import org.apache.spark.api.java.JavaRDD;
7+
import org.apache.spark.api.java.JavaSparkContext;
8+
import org.apache.spark.mllib.linalg.Matrix;
9+
import org.apache.spark.mllib.linalg.Vector;
10+
import org.apache.spark.mllib.linalg.Vectors;
11+
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
12+
import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
13+
import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools;
14+
15+
public class PCADemo
16+
{
17+
// TODO: check https://stackoverflow.com/questions/10604507/pca-implementation-in-java
18+
public static void main( String[] args )
19+
{
20+
double[][] inputData = RandomDataTools.generateSampleData();
21+
double[][] result = setUpPCA( inputData );
22+
PlotPoints.plot( inputData, result, resultValues -> resultValues[ 0 ] > 10 );
23+
}
24+
25+
static double[][] setUpPCA( double[][] inputData )
26+
{
27+
try (JavaSparkContext jsc = new JavaSparkContext( "local", "PCA" ))
28+
{
29+
List< Vector > data = new ArrayList<>();
30+
for ( final double[] row : inputData )
31+
data.add( Vectors.dense( row ) );
32+
JavaRDD< Vector > rows = jsc.parallelize( data );
33+
// Create a RowMatrix from JavaRDD<Vector>.
34+
RowMatrix rowMatrix = new RowMatrix( rows.rdd() );
35+
// Compute the top 2 principal components.
36+
// Principal components are stored in a local dense matrix.
37+
Matrix pc = rowMatrix.computePrincipalComponents( 2 );
38+
// Project the rows to the linear space spanned by the top 4 principal components.
39+
RowMatrix projected = rowMatrix.multiply( pc );
40+
return projected.rows().toJavaRDD().collect().stream().map( Vector::toArray ).toArray( double[][]::new );
41+
}
42+
}
43+
}

0 commit comments

Comments
 (0)