Skip to content

Commit 3c0111c

Browse files
Merge pull request #152 from mastodon-sc/stardist-pixi
Add StarDist GPU based on pixi
2 parents b1b1596 + 0e643bb commit 3c0111c

24 files changed

Lines changed: 1287 additions & 178 deletions

File tree

src/main/java/org/mastodon/mamut/detection/DeepLearningDetector.java

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
import net.imglib2.view.Views;
4444

4545
import org.apache.commons.lang3.tuple.Pair;
46-
import org.apposed.appose.Appose;
46+
import org.apposed.appose.Builder;
4747
import org.apposed.appose.Environment;
4848
import org.apposed.appose.Service;
4949
import org.mastodon.mamut.detection.util.SpimImageProperties;
@@ -130,7 +130,9 @@ public void compute( final List< SourceAndConverter< ? > > sources, final ModelG
130130
final Source< ? > source = sources.get( settings.getSetupId() ).getSpimSource();
131131
RandomAccessibleInterval< ? > image = source.getSource( 0, 0 );
132132

133-
Service.Task importTask = python.task( getImportScript( is2D( image ) ), "main" );
133+
String importScript = getImportScript( ImgUtils.is2D( image ) );
134+
logger.info( "import script:\n{}", importScript );
135+
Service.Task importTask = python.task( importScript, "main" );
134136
importTask.waitFor();
135137
this.pythonService = python;
136138
for ( int timepoint = minTimepoint; timepoint <= maxTimepoint; timepoint++ )
@@ -187,7 +189,7 @@ private Environment prepareEnvironment() throws IOException
187189
*/
188190
private Environment buildEnvironment() throws IOException
189191
{
190-
return Appose.mamba().scheme( "environment.yml" ).content( getPythonEnvContent() ).logDebug()
192+
return getBuilder().content( getPythonEnvContent() ).logDebug()
191193
.subscribeProgress( ( title, cur, max ) -> logger.info( "{}: {}/{}", title, cur, max ) )
192194
.subscribeOutput( logger::info )
193195
.subscribeError( logger::error ).build();
@@ -199,7 +201,8 @@ private SpimImageProperties extractSettings( final List< SourceAndConverter< ? >
199201
final int minTimepoint = ( int ) settings.get( DetectorKeys.KEY_MIN_TIMEPOINT );
200202
final int maxTimepoint = ( int ) settings.get( DetectorKeys.KEY_MAX_TIMEPOINT );
201203
final int setup = ( int ) settings.get( DetectorKeys.KEY_SETUP_ID );
202-
final int level = ( int ) settings.get( DeepLearningDetectorKeys.KEY_LEVEL );
204+
final Object levelObject = settings.get( DeepLearningDetectorKeys.KEY_LEVEL );
205+
final int level = levelObject == null ? 0 : ( int ) settings.get( DeepLearningDetectorKeys.KEY_LEVEL );
203206

204207
logger.info( "Settings contain, minTimepoint: {}, maxTimepoint: {}, setup: {} and level: {}", minTimepoint, maxTimepoint, setup,
205208
level );
@@ -267,9 +270,8 @@ private void detectAndAddSpots( final List< SourceAndConverter< ? > > sources, f
267270
image = Views.interval( image, roi );
268271
}
269272

270-
final Img< ? > segmentation =
271-
performSegmentation( Views.dropSingletonDimensions( image ), source.getVoxelDimensions().dimensionsAsDoubleArray(),
272-
python );
273+
final Img< ? > segmentation = performSegmentation( Views.dropSingletonDimensions( image ),
274+
source.getVoxelDimensions().dimensionsAsDoubleArray(), python );
273275

274276
if ( segmentation != null )
275277
{
@@ -326,28 +328,6 @@ protected double getAnisotropy( double[] voxelSizes, boolean is3D )
326328
return highestValue / lowestValue;
327329
}
328330

329-
protected boolean is2D( final RandomAccessibleInterval< ? > image )
330-
{
331-
return !is3D( image );
332-
}
333-
334-
protected boolean is3D( final RandomAccessibleInterval< ? > image )
335-
{
336-
long[] dimensions = image.dimensionsAsLongArray();
337-
if ( dimensions.length <= 2 )
338-
return false;
339-
else
340-
{
341-
int nonPlaneDimensionCount = 0;
342-
for ( final long dimension : dimensions )
343-
{
344-
if ( dimension > 1 )
345-
nonPlaneDimensionCount++;
346-
}
347-
return nonPlaneDimensionCount > 2;
348-
}
349-
}
350-
351331
@Override
352332
public Map< String, Object > getDefaultSettings()
353333
{
@@ -374,6 +354,8 @@ public Map< String, Object > getDefaultSettings()
374354

375355
protected abstract String getPythonEnvName();
376356

357+
protected abstract Builder< ? > getBuilder();
358+
377359
protected String getPythonEnvInit()
378360
{
379361
return "import numpy\n";

src/main/java/org/mastodon/mamut/detection/Segmentation.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import java.util.Arrays;
3434
import java.util.stream.Collectors;
3535

36+
import javax.annotation.Nullable;
37+
3638
import net.imglib2.RandomAccessibleInterval;
3739
import net.imglib2.appose.NDArrays;
3840
import net.imglib2.appose.ShmImg;
@@ -54,9 +56,9 @@ public abstract class Segmentation extends ApposeProcess
5456
{
5557
private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
5658

57-
protected Segmentation( final Service pythonService )
59+
protected Segmentation( final Service pythonService, final @Nullable org.scijava.log.Logger scijavaLogger )
5860
{
59-
super( pythonService );
61+
super( pythonService, scijavaLogger );
6062
}
6163

6264
/**

src/main/java/org/mastodon/mamut/detection/cellpose/Cellpose.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
import java.io.IOException;
3232

33+
import javax.annotation.Nullable;
34+
3335
import org.apposed.appose.Service;
3436
import org.mastodon.mamut.detection.Segmentation;
3537

@@ -59,9 +61,9 @@ public abstract class Cellpose extends Segmentation
5961

6062
public static final double DEFAULT_DIAMETER = 0d;
6163

62-
protected Cellpose( final Service python ) throws IOException
64+
protected Cellpose( final Service python, final @Nullable org.scijava.log.Logger scijavaLogger ) throws IOException
6365
{
64-
super( python );
66+
super( python, scijavaLogger );
6567
}
6668

6769
public void setCellProbThreshold( final double cellProbThreshold )

src/main/java/org/mastodon/mamut/detection/cellpose/Cellpose3.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
import java.io.IOException;
3232

33+
import javax.annotation.Nullable;
34+
3335
import org.apposed.appose.Service;
3436

3537
/**
@@ -61,9 +63,10 @@ public class Cellpose3 extends Cellpose
6163

6264
private double anisotropy = 1;
6365

64-
public Cellpose3( final ModelType modelType, final Service python ) throws IOException
66+
public Cellpose3( final ModelType modelType, final Service python, final @Nullable org.scijava.log.Logger scijavaLogger )
67+
throws IOException
6568
{
66-
super( python );
69+
super( python, scijavaLogger );
6770
this.modelType = modelType;
6871
}
6972

src/main/java/org/mastodon/mamut/detection/cellpose/Cellpose3Detector.java

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@
5252
import net.imglib2.img.Img;
5353
import net.imglib2.util.Cast;
5454

55+
import org.apposed.appose.Appose;
56+
import org.apposed.appose.Builder;
5557
import org.apposed.appose.Service;
5658
import org.mastodon.mamut.detection.DeepLearningDetector;
57-
import org.mastodon.mamut.detection.stardist.StarDist;
59+
import org.mastodon.mamut.util.ImgUtils;
5860
import org.mastodon.tracking.mamut.detection.SpotDetectorOp;
61+
import org.mastodon.tracking.mamut.trackmate.wizard.descriptors.cellpose.Cellpose4DetectorDescriptor;
5962
import org.scijava.Priority;
6063
import org.scijava.plugin.Plugin;
6164
import org.slf4j.Logger;
@@ -106,12 +109,26 @@ && checkParameter( settings, KEY_GPU_ID, Integer.class, errorHolder )
106109

107110
try
108111
{
109-
Cellpose3 cellpose = new Cellpose3( ( Cellpose3.ModelType ) settings.get( KEY_MODEL_TYPE ), python );
110-
boolean is3D = is3D( image );
112+
Cellpose3 cellpose = new Cellpose3( ( Cellpose3.ModelType ) settings.get( KEY_MODEL_TYPE ), python, log );
113+
boolean is3D = ImgUtils.is3D( image );
111114
cellpose.set3D( is3D );
112115
cellpose.setCellProbThreshold( ( double ) settings.get( KEY_CELL_PROBABILITY_THRESHOLD ) );
113116
cellpose.setFlowThreshold( ( double ) settings.get( KEY_FLOW_THRESHOLD ) );
114-
cellpose.setDiameter( ( double ) settings.get( KEY_DIAMETER ) );
117+
Object diameterObject = settings.get( Cellpose4DetectorDescriptor.KEY_DIAMETER );
118+
if ( diameterObject != null )
119+
{
120+
double diameter = ( double ) diameterObject;
121+
int level = ( int ) settings.get( KEY_LEVEL );
122+
if ( level != 0 )
123+
{
124+
// Adjust diameter based on the pyramid level
125+
diameter = diameter / Math.pow( 2, level );
126+
logger.info( "Adjusted diameter for pyramid level {}: {}", level, diameter );
127+
}
128+
cellpose.setDiameter( diameter );
129+
}
130+
else
131+
cellpose.setDiameter( 0 );
115132
cellpose.setGpuID( ( int ) settings.get( KEY_GPU_ID ) );
116133
cellpose.setGpuMemoryFraction( ( double ) settings.get( KEY_GPU_MEMORY_FRACTION ) );
117134
final boolean respectAnisotropy = ( boolean ) settings.get( KEY_RESPECT_ANISOTROPY );
@@ -159,9 +176,21 @@ protected String getPythonEnvName()
159176
return Cellpose3.ENV_NAME;
160177
}
161178

179+
@Override
180+
protected Builder< ? > getBuilder()
181+
{
182+
return Appose.mamba().scheme( "environment.yml" );
183+
}
184+
162185
@Override
163186
protected String getImportScript( final boolean dataIs2D )
164187
{
165-
return Cellpose3.generateImportStatements();
188+
return Cellpose.generateImportStatements();
189+
}
190+
191+
@Override
192+
protected String getPythonEnvInit()
193+
{
194+
return "import numpy\nfrom cellpose import models\n"; // NB: StarDist2D import needs to be inited even for 3D cases
166195
}
167196
}

src/main/java/org/mastodon/mamut/detection/cellpose/Cellpose4.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
import java.io.IOException;
3232

33+
import javax.annotation.Nullable;
34+
3335
import org.apposed.appose.Service;
3436

3537
/**
@@ -56,9 +58,9 @@ public class Cellpose4 extends Cellpose
5658
+ " - pip:\n"
5759
+ " - appose==" + APPOSE_PYTHON_VERSION + "\n";
5860

59-
public Cellpose4( final Service python ) throws IOException
61+
public Cellpose4( final Service python, final @Nullable org.scijava.log.Logger scijavaLogger ) throws IOException
6062
{
61-
super( python );
63+
super( python, scijavaLogger );
6264
}
6365

6466
@Override

src/main/java/org/mastodon/mamut/detection/cellpose/Cellpose4Detector.java

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@
4848
import net.imglib2.img.Img;
4949
import net.imglib2.util.Cast;
5050

51+
import org.apposed.appose.Appose;
52+
import org.apposed.appose.Builder;
5153
import org.apposed.appose.Service;
5254
import org.mastodon.mamut.detection.DeepLearningDetector;
53-
import org.mastodon.mamut.detection.stardist.StarDist;
55+
import org.mastodon.mamut.util.ImgUtils;
5456
import org.mastodon.tracking.mamut.detection.SpotDetectorOp;
5557
import org.mastodon.tracking.mamut.trackmate.wizard.descriptors.cellpose.Cellpose4DetectorDescriptor;
5658
import org.scijava.Priority;
@@ -97,11 +99,25 @@ && checkParameter( settings, KEY_GPU_ID, Integer.class, errorHolder )
9799
{
98100
try
99101
{
100-
Cellpose4 cellpose = new Cellpose4( python );
101-
cellpose.set3D( is3D( image ) );
102+
Cellpose4 cellpose = new Cellpose4( python, log );
103+
cellpose.set3D( ImgUtils.is3D( image ) );
102104
cellpose.setCellProbThreshold( ( double ) settings.get( KEY_CELL_PROBABILITY_THRESHOLD ) );
103105
cellpose.setFlowThreshold( ( double ) settings.get( KEY_FLOW_THRESHOLD ) );
104-
cellpose.setDiameter( ( double ) settings.get( KEY_DIAMETER ) );
106+
Object diameterObject = settings.get( Cellpose4DetectorDescriptor.KEY_DIAMETER );
107+
if ( diameterObject != null )
108+
{
109+
double diameter = ( double ) diameterObject;
110+
int level = ( int ) settings.get( KEY_LEVEL );
111+
if ( level != 0 )
112+
{
113+
// Adjust diameter based on the pyramid level
114+
diameter = diameter / Math.pow( 2, level );
115+
logger.info( "Adjusted diameter for pyramid level {}: {}", level, diameter );
116+
}
117+
cellpose.setDiameter( diameter );
118+
}
119+
else
120+
cellpose.setDiameter( 0 );
105121
cellpose.setGpuID( ( int ) settings.get( KEY_GPU_ID ) );
106122
cellpose.setGpuMemoryFraction( ( double ) settings.get( KEY_GPU_MEMORY_FRACTION ) );
107123
return cellpose.segmentImage( Cast.unchecked( image ) );
@@ -144,9 +160,21 @@ protected String getPythonEnvName()
144160
return Cellpose4.ENV_NAME;
145161
}
146162

163+
@Override
164+
protected Builder< ? > getBuilder()
165+
{
166+
return Appose.mamba().scheme( "environment.yml" );
167+
}
168+
147169
@Override
148170
protected String getImportScript( final boolean dataIs2D )
149171
{
150172
return Cellpose.generateImportStatements();
151173
}
174+
175+
@Override
176+
protected String getPythonEnvInit()
177+
{
178+
return "import numpy\nfrom cellpose import models\n"; // NB: StarDist2D import needs to be inited even for 3D cases
179+
}
152180
}

0 commit comments

Comments
 (0)