Skip to content

Commit d7ca475

Browse files
author
Stefan Hahmann
committed
Add demo applications and a (still) failing unit test for the UMAP implementation of the smile library
1 parent d4b9793 commit d7ca475

File tree

4 files changed

+278
-0
lines changed

4 files changed

+278
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.mastodon.mamut.feature.dimensionalityreduction.umap;
2+
3+
import java.util.Arrays;
4+
5+
import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
6+
import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools;
7+
8+
import smile.manifold.UMAP;
9+
10+
public class UmapSmileDemo
11+
{
12+
public static void main( final String[] args )
13+
{
14+
int numCluster1 = 50;
15+
int numCluster2 = 100;
16+
double[][] sampleData = RandomDataTools.generateSampleData( numCluster1, numCluster2 );
17+
UMAP umap = setUpUmap( sampleData );
18+
double[][] umapResult = umap.coordinates;
19+
double[][] result = Arrays.stream( umapResult ).map( row -> Arrays.stream( row ).map( value -> value * 10d ).toArray() ) // scale up
20+
.toArray( double[][]::new );
21+
PlotPoints.plot( sampleData, result, resultValues -> resultValues[ 0 ] > 1 );
22+
}
23+
24+
static UMAP setUpUmap( final double[][] sampleData )
25+
{
26+
int iterations = sampleData.length < 10_000 ? 500 : 200; // https://github.com/lmcinnes/umap/blob/a012b9d8751d98b94935ca21f278a54b3c3e1b7f/umap/umap_.py#L1073
27+
double minDist = 0.1;
28+
int nNeighbors = 15;
29+
return UMAP.of( sampleData, nNeighbors, 2, iterations, 1, minDist, 1.0, 5, 1 );
30+
}
31+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package org.mastodon.mamut.feature.dimensionalityreduction.umap;
2+
3+
import java.io.File;
4+
import java.io.FileReader;
5+
import java.io.IOException;
6+
import java.util.ArrayList;
7+
import java.util.Arrays;
8+
import java.util.List;
9+
10+
import com.opencsv.CSVParserBuilder;
11+
import com.opencsv.CSVReader;
12+
import com.opencsv.CSVReaderBuilder;
13+
import com.opencsv.exceptions.CsvValidationException;
14+
15+
import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
16+
import org.mastodon.mamut.feature.dimensionalityreduction.util.StandardScaler;
17+
18+
import smile.manifold.UMAP;
19+
20+
public class UmapSmileDemoIris
21+
{
22+
public static void main( String[] args ) throws IOException, CsvValidationException
23+
{
24+
File file = new File( "src/test/resources/org/mastodon/mamut/feature/dimensionalityreduction/iris.tsv" );
25+
CSVParserBuilder parserBuilder = new CSVParserBuilder().withSeparator( '\t' );
26+
CSVReaderBuilder builder = new CSVReaderBuilder( new FileReader( file ) ).withCSVParser( parserBuilder.build() );
27+
try (CSVReader reader = builder.build())
28+
{
29+
List< double[] > data = new ArrayList<>();
30+
reader.readNext(); // skip header
31+
for ( String[] nextLine; ( nextLine = reader.readNext() ) != null; )
32+
{
33+
double[] values = new double[ nextLine.length - 1 ];
34+
for ( int i = 1; i < nextLine.length; i++ )
35+
values[ i - 1 ] = Double.parseDouble( nextLine[ i ] );
36+
data.add( values );
37+
}
38+
double[][] inputData = data.toArray( new double[ data.size() ][ data.get( 0 ).length ] );
39+
StandardScaler.standardizeColumns( inputData );
40+
UMAP umap = setUpUmap( inputData );
41+
double[][] result = umap.coordinates;
42+
result = Arrays.stream( result ).map( row -> Arrays.stream( row ).map( value -> value * 10d ).toArray() ) // scale up
43+
.toArray( double[][]::new );
44+
PlotPoints.plot( null, result, null );
45+
}
46+
}
47+
48+
static UMAP setUpUmap( double[][] data )
49+
{
50+
int iterations = data.length < 10_000 ? 500 : 200; // https://github.com/lmcinnes/umap/blob/a012b9d8751d98b94935ca21f278a54b3c3e1b7f/umap/umap_.py#L1073
51+
double minDist = 0.1;
52+
int nNeighbors = 15;
53+
return UMAP.of( data, nNeighbors, 2, iterations, 1, minDist, 1.0, 5, 1 );
54+
}
55+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package org.mastodon.mamut.feature.dimensionalityreduction.umap;
2+
3+
import java.io.File;
4+
import java.io.FileReader;
5+
import java.io.IOException;
6+
import java.util.ArrayList;
7+
import java.util.Arrays;
8+
import java.util.List;
9+
10+
import com.opencsv.CSVParserBuilder;
11+
import com.opencsv.CSVReader;
12+
import com.opencsv.CSVReaderBuilder;
13+
import com.opencsv.exceptions.CsvValidationException;
14+
15+
import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
16+
import org.mastodon.mamut.feature.dimensionalityreduction.util.StandardScaler;
17+
18+
import smile.manifold.UMAP;
19+
20+
public class UmapSmileDemoTgmmMini
21+
{
22+
public static void main( String[] args ) throws IOException, CsvValidationException
23+
{
24+
File file = new File( "src/test/resources/org/mastodon/mamut/feature/dimensionalityreduction/tgmm-mini-spot.csv" );
25+
CSVParserBuilder parserBuilder = new CSVParserBuilder().withSeparator( ',' );
26+
CSVReaderBuilder builder = new CSVReaderBuilder( new FileReader( file ) ).withCSVParser( parserBuilder.build() );
27+
try (CSVReader reader = builder.build())
28+
{
29+
List< double[] > data = new ArrayList<>();
30+
reader.readNext(); // skip header
31+
for ( String[] nextLine; ( nextLine = reader.readNext() ) != null; )
32+
{
33+
double[] values = new double[ nextLine.length ];
34+
for ( int i = 0; i < nextLine.length; i++ )
35+
values[ i ] = Double.parseDouble( nextLine[ i ] );
36+
data.add( values );
37+
}
38+
double[][] inputData = data.toArray( new double[ data.size() ][ data.get( 0 ).length ] );
39+
StandardScaler.standardizeColumns( inputData );
40+
UMAP umap = setUpUmap( inputData );
41+
double[][] result = umap.coordinates;
42+
double[][] resultScaled = Arrays.stream( result ).map( row -> Arrays.stream( row ).map( value -> value * 10d ).toArray() )
43+
.toArray( double[][]::new );
44+
PlotPoints.plot( null, resultScaled, null );
45+
}
46+
}
47+
48+
static UMAP setUpUmap( double[][] data )
49+
{
50+
int iterations = data.length < 10_000 ? 500 : 200; // https://github.com/lmcinnes/umap/blob/a012b9d8751d98b94935ca21f278a54b3c3e1b7f/umap/umap_.py#L1073
51+
double minDist = 0.1;
52+
int nNeighbors = 15;
53+
return UMAP.of( data, nNeighbors, 2, iterations, 1, minDist, 1.0, 5, 1 );
54+
}
55+
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package org.mastodon.mamut.feature.dimensionalityreduction.umap;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

import java.awt.geom.Rectangle2D;
import java.util.Arrays;
import java.util.Random;

import org.junit.Ignore;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import smile.manifold.TSNE;
import smile.manifold.UMAP;
15+
16+
@Ignore( "This test is failing due to a bug in the smile UMAP implementation, cf.: https://github.com/haifengl/smile/pull/796 " )
17+
class UmapSmileTest
18+
{
19+
@Test
20+
void test()
21+
{
22+
int numCluster1 = 50;
23+
int numCluster2 = 100;
24+
// create two distinct clusters of points in 3D space, one having 50 points and the other 100 points
25+
double[][] sampleData = generateSampleData( numCluster1, numCluster2 );
26+
27+
// Recommendations for t-SNE defaults: https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
28+
double perplexity = 30d; // recommended value is between 5 and 50
29+
int maxIterations = 1000; // should be at least 250
30+
31+
TSNE tsne = new TSNE( sampleData, 2, perplexity, 200, maxIterations );
32+
double[][] tsneResult = tsne.coordinates;
33+
34+
assertEquals( tsneResult.length, sampleData.length ); // passes
35+
assertEquals( 2, tsneResult[ 0 ].length ); // passes
36+
37+
double[][] tsneResult1 = Arrays.copyOfRange( tsneResult, 0, numCluster1 );
38+
double[][] tsneResult2 = Arrays.copyOfRange( tsneResult, numCluster1, numCluster1 + numCluster2 );
39+
40+
testNonOverlappingClusters( tsneResult1, tsneResult2 ); // passes
41+
42+
// Recommendations for UMAP defaults: https://github.com/lmcinnes/umap/blob/a012b9d8751d98b94935ca21f278a54b3c3e1b7f/umap/umap_.py#L1073
43+
int iterations = sampleData.length < 10_000 ? 500 : 200;
44+
double minDist = 0.1;
45+
int nNeighbors = 15;
46+
UMAP umap = UMAP.of( sampleData, nNeighbors, 2, iterations, 1, minDist, 1.0, 5, 1 );
47+
double[][] umapResult = umap.coordinates;
48+
49+
assertEquals( 2, umapResult[ 0 ].length );
50+
assertEquals( umapResult.length, sampleData.length ); // fails, because only the largest connected component is used inside the algorithm
51+
52+
double[][] umapResult1 = Arrays.copyOfRange( umapResult, 0, numCluster1 );
53+
double[][] umapResult2 = Arrays.copyOfRange( umapResult, numCluster1, numCluster1 + numCluster2 );
54+
55+
testNonOverlappingClusters( umapResult1, umapResult2 ); // should pass
56+
}
57+
58+
private static double[][] generateSampleData( int numCluster1, int numCluster2 )
59+
{
60+
double[][] firstPointCloud = generateRandomPointsInSphere( 100, 100, -10, 20, numCluster1 );
61+
double[][] secondPointCloud = generateRandomPointsInSphere( 250, 250, 10, 50, numCluster2 );
62+
63+
return concatenateArrays( firstPointCloud, secondPointCloud );
64+
}
65+
66+
private static double[][] concatenateArrays( final double[][] firstPointCloud, final double[][] secondPointCloud )
67+
{
68+
double[][] concatenated = new double[ firstPointCloud.length + secondPointCloud.length ][ 2 ];
69+
System.arraycopy( firstPointCloud, 0, concatenated, 0, firstPointCloud.length );
70+
System.arraycopy( secondPointCloud, 0, concatenated, firstPointCloud.length, secondPointCloud.length );
71+
return concatenated;
72+
}
73+
74+
private static double[][] generateRandomPointsInSphere( double centerX, double centerY, double centerZ, double radius,
75+
int numberOfPoints )
76+
{
77+
double[][] points = new double[ numberOfPoints ][ 3 ];
78+
79+
final Random random = new Random( 42 );
80+
81+
for ( int i = 0; i < numberOfPoints; i++ )
82+
{
83+
double r = radius * Math.cbrt( random.nextDouble() );
84+
double theta = 2 * Math.PI * random.nextDouble();
85+
double phi = Math.acos( 2 * random.nextDouble() - 1 );
86+
87+
double x = centerX + r * Math.sin( phi ) * Math.cos( theta );
88+
double y = centerY + r * Math.sin( phi ) * Math.sin( theta );
89+
double z = centerZ + r * Math.cos( phi );
90+
91+
points[ i ][ 0 ] = x;
92+
points[ i ][ 1 ] = y;
93+
points[ i ][ 2 ] = z;
94+
}
95+
96+
return points;
97+
}
98+
99+
private static void testNonOverlappingClusters( final double[][] cluster1, final double[][] cluster2 )
100+
{
101+
Rectangle2D.Double boundingBox1 = findBoundingBox( cluster1 );
102+
Rectangle2D.Double boundingBox2 = findBoundingBox( cluster2 );
103+
testNoPointInsideBoundingBox( cluster1, boundingBox2 );
104+
testNoPointInsideBoundingBox( cluster2, boundingBox1 );
105+
}
106+
107+
private static void testNoPointInsideBoundingBox( final double[][] points, final Rectangle2D.Double boundingBox )
108+
{
109+
for ( double[] point : points )
110+
{
111+
double x = point[ 0 ];
112+
double y = point[ 1 ];
113+
assertFalse( boundingBox.contains( x, y ) );
114+
}
115+
}
116+
117+
private static Rectangle2D.Double findBoundingBox( double[][] points )
118+
{
119+
double minX = Double.MAX_VALUE;
120+
double minY = Double.MAX_VALUE;
121+
double maxX = Double.MIN_VALUE;
122+
double maxY = Double.MIN_VALUE;
123+
124+
for ( double[] point : points )
125+
{
126+
double x = point[ 0 ];
127+
double y = point[ 1 ];
128+
129+
minX = Math.min( x, minX );
130+
maxX = Math.max( x, maxX );
131+
minY = Math.min( y, minY );
132+
maxY = Math.max( y, maxY );
133+
}
134+
135+
return new Rectangle2D.Double( minX, minY, maxX - minX, maxY - minY );
136+
}
137+
}

0 commit comments

Comments
 (0)