|
| 1 | +package org.mastodon.mamut.feature.dimensionalityreduction.umap; |
| 2 | + |
| 3 | +import static org.junit.jupiter.api.Assertions.assertEquals; |
| 4 | +import static org.junit.jupiter.api.Assertions.assertFalse; |
| 5 | + |
| 6 | +import java.awt.geom.Rectangle2D; |
| 7 | +import java.util.Arrays; |
| 8 | +import java.util.Random; |
| 9 | + |
| 10 | +import org.junit.Ignore; |
| 11 | +import org.junit.jupiter.api.Test; |
| 12 | + |
| 13 | +import smile.manifold.TSNE; |
| 14 | +import smile.manifold.UMAP; |
| 15 | + |
| 16 | +@Ignore( "This test is failing due to a bug in the smile UMAP implementation, cf.: https://github.com/haifengl/smile/pull/796 " ) |
| 17 | +class UmapSmileTest |
| 18 | +{ |
| 19 | + @Test |
| 20 | + void test() |
| 21 | + { |
| 22 | + int numCluster1 = 50; |
| 23 | + int numCluster2 = 100; |
| 24 | + // create two distinct clusters of points in 3D space, one having 50 points and the other 100 points |
| 25 | + double[][] sampleData = generateSampleData( numCluster1, numCluster2 ); |
| 26 | + |
| 27 | + // Recommendations for t-SNE defaults: https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html |
| 28 | + double perplexity = 30d; // recommended value is between 5 and 50 |
| 29 | + int maxIterations = 1000; // should be at least 250 |
| 30 | + |
| 31 | + TSNE tsne = new TSNE( sampleData, 2, perplexity, 200, maxIterations ); |
| 32 | + double[][] tsneResult = tsne.coordinates; |
| 33 | + |
| 34 | + assertEquals( tsneResult.length, sampleData.length ); // passes |
| 35 | + assertEquals( 2, tsneResult[ 0 ].length ); // passes |
| 36 | + |
| 37 | + double[][] tsneResult1 = Arrays.copyOfRange( tsneResult, 0, numCluster1 ); |
| 38 | + double[][] tsneResult2 = Arrays.copyOfRange( tsneResult, numCluster1, numCluster1 + numCluster2 ); |
| 39 | + |
| 40 | + testNonOverlappingClusters( tsneResult1, tsneResult2 ); // passes |
| 41 | + |
| 42 | + // Recommendations for UMAP defaults: https://github.com/lmcinnes/umap/blob/a012b9d8751d98b94935ca21f278a54b3c3e1b7f/umap/umap_.py#L1073 |
| 43 | + int iterations = sampleData.length < 10_000 ? 500 : 200; |
| 44 | + double minDist = 0.1; |
| 45 | + int nNeighbors = 15; |
| 46 | + UMAP umap = UMAP.of( sampleData, nNeighbors, 2, iterations, 1, minDist, 1.0, 5, 1 ); |
| 47 | + double[][] umapResult = umap.coordinates; |
| 48 | + |
| 49 | + assertEquals( 2, umapResult[ 0 ].length ); |
| 50 | + assertEquals( umapResult.length, sampleData.length ); // fails, because only the largest connected component is used inside the algorithm |
| 51 | + |
| 52 | + double[][] umapResult1 = Arrays.copyOfRange( umapResult, 0, numCluster1 ); |
| 53 | + double[][] umapResult2 = Arrays.copyOfRange( umapResult, numCluster1, numCluster1 + numCluster2 ); |
| 54 | + |
| 55 | + testNonOverlappingClusters( umapResult1, umapResult2 ); // should pass |
| 56 | + } |
| 57 | + |
| 58 | + private static double[][] generateSampleData( int numCluster1, int numCluster2 ) |
| 59 | + { |
| 60 | + double[][] firstPointCloud = generateRandomPointsInSphere( 100, 100, -10, 20, numCluster1 ); |
| 61 | + double[][] secondPointCloud = generateRandomPointsInSphere( 250, 250, 10, 50, numCluster2 ); |
| 62 | + |
| 63 | + return concatenateArrays( firstPointCloud, secondPointCloud ); |
| 64 | + } |
| 65 | + |
| 66 | + private static double[][] concatenateArrays( final double[][] firstPointCloud, final double[][] secondPointCloud ) |
| 67 | + { |
| 68 | + double[][] concatenated = new double[ firstPointCloud.length + secondPointCloud.length ][ 2 ]; |
| 69 | + System.arraycopy( firstPointCloud, 0, concatenated, 0, firstPointCloud.length ); |
| 70 | + System.arraycopy( secondPointCloud, 0, concatenated, firstPointCloud.length, secondPointCloud.length ); |
| 71 | + return concatenated; |
| 72 | + } |
| 73 | + |
| 74 | + private static double[][] generateRandomPointsInSphere( double centerX, double centerY, double centerZ, double radius, |
| 75 | + int numberOfPoints ) |
| 76 | + { |
| 77 | + double[][] points = new double[ numberOfPoints ][ 3 ]; |
| 78 | + |
| 79 | + final Random random = new Random( 42 ); |
| 80 | + |
| 81 | + for ( int i = 0; i < numberOfPoints; i++ ) |
| 82 | + { |
| 83 | + double r = radius * Math.cbrt( random.nextDouble() ); |
| 84 | + double theta = 2 * Math.PI * random.nextDouble(); |
| 85 | + double phi = Math.acos( 2 * random.nextDouble() - 1 ); |
| 86 | + |
| 87 | + double x = centerX + r * Math.sin( phi ) * Math.cos( theta ); |
| 88 | + double y = centerY + r * Math.sin( phi ) * Math.sin( theta ); |
| 89 | + double z = centerZ + r * Math.cos( phi ); |
| 90 | + |
| 91 | + points[ i ][ 0 ] = x; |
| 92 | + points[ i ][ 1 ] = y; |
| 93 | + points[ i ][ 2 ] = z; |
| 94 | + } |
| 95 | + |
| 96 | + return points; |
| 97 | + } |
| 98 | + |
| 99 | + private static void testNonOverlappingClusters( final double[][] cluster1, final double[][] cluster2 ) |
| 100 | + { |
| 101 | + Rectangle2D.Double boundingBox1 = findBoundingBox( cluster1 ); |
| 102 | + Rectangle2D.Double boundingBox2 = findBoundingBox( cluster2 ); |
| 103 | + testNoPointInsideBoundingBox( cluster1, boundingBox2 ); |
| 104 | + testNoPointInsideBoundingBox( cluster2, boundingBox1 ); |
| 105 | + } |
| 106 | + |
| 107 | + private static void testNoPointInsideBoundingBox( final double[][] points, final Rectangle2D.Double boundingBox ) |
| 108 | + { |
| 109 | + for ( double[] point : points ) |
| 110 | + { |
| 111 | + double x = point[ 0 ]; |
| 112 | + double y = point[ 1 ]; |
| 113 | + assertFalse( boundingBox.contains( x, y ) ); |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + private static Rectangle2D.Double findBoundingBox( double[][] points ) |
| 118 | + { |
| 119 | + double minX = Double.MAX_VALUE; |
| 120 | + double minY = Double.MAX_VALUE; |
| 121 | + double maxX = Double.MIN_VALUE; |
| 122 | + double maxY = Double.MIN_VALUE; |
| 123 | + |
| 124 | + for ( double[] point : points ) |
| 125 | + { |
| 126 | + double x = point[ 0 ]; |
| 127 | + double y = point[ 1 ]; |
| 128 | + |
| 129 | + minX = Math.min( x, minX ); |
| 130 | + maxX = Math.max( x, maxX ); |
| 131 | + minY = Math.min( y, minY ); |
| 132 | + maxY = Math.max( y, maxY ); |
| 133 | + } |
| 134 | + |
| 135 | + return new Rectangle2D.Double( minX, minY, maxX - minX, maxY - minY ); |
| 136 | + } |
| 137 | +} |
0 commit comments