Skip to content

Commit fe6b302

Browse files
author
Stefan Hahmann
committed
Add stardist model selection
1 parent a9a942f commit fe6b302

5 files changed

Lines changed: 236 additions & 106 deletions

File tree

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package org.mastodon.mamut.detection;
2+
3+
import java.io.File;
4+
import java.io.IOException;
5+
import java.lang.invoke.MethodHandles;
6+
import java.nio.file.Path;
7+
import java.nio.file.Paths;
8+
9+
import org.slf4j.Logger;
10+
import org.slf4j.LoggerFactory;
11+
12+
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
13+
14+
public class StarDist extends Segmentation3D
15+
{
16+
private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
17+
18+
private final ModelType modelType;
19+
20+
private boolean dataIs2D;
21+
22+
private static final String PATH = "stardist";
23+
24+
public StarDist( final ModelType model ) throws IOException, InterruptedException
25+
{
26+
super();
27+
this.modelType = model;
28+
if ( modelType == null )
29+
logger.info( "No star dist model path specified. Using pretrained demo model." );
30+
else
31+
{
32+
Path starDistModelRoot =
33+
Paths.get( System.getProperty( "user.home" ), ".local", "share", "stardist", "models", modelType.getModelPath() );
34+
File directory = starDistModelRoot.toFile();
35+
if ( directory.isDirectory() )
36+
{
37+
String[] files = directory.list();
38+
if ( files != null && files.length > 0 )
39+
logger.debug( "Found {} files in {}", files.length, directory.getAbsolutePath() );
40+
else
41+
BioimageioRepo.connect().downloadByName( "StarDist Plant Nuclei 3D ResNet", directory.getAbsolutePath() );
42+
}
43+
else
44+
logger.error( "The specified path is not a directory: {}", directory.getAbsolutePath() );
45+
}
46+
}
47+
48+
@Override
49+
String generateEnvFileContent()
50+
{
51+
return "name: stardist\n"
52+
+ "channels:\n"
53+
+ " - conda-forge\n"
54+
+ "dependencies:\n"
55+
+ " - python=3.10\n"
56+
+ " - cudatoolkit=11.2\n"
57+
+ " - cudnn=8.1.0\n"
58+
+ " - numpy<1.24\n"
59+
+ " - pip\n"
60+
+ " - pip:\n"
61+
+ " - numpy<1.24\n"
62+
+ " - tensorflow==2.10\n"
63+
+ " - stardist==0.8.5\n"
64+
+ " - appose\n";
65+
}
66+
67+
@Override
68+
String generateScript()
69+
{
70+
return "import numpy as np" + "\n"
71+
+ "import appose" + "\n"
72+
+ "from csbdeep.utils import normalize" + "\n"
73+
+ getImportStarDistCommand()
74+
+ "\n"
75+
+ "task.update(message=\"Imports completed\")" + "\n"
76+
+ "np.random.seed(6)" + "\n"
77+
+ "axes_normalize = (0, 1, 2)" + "\n"
78+
+ "\n"
79+
+ "task.update(message=\"Loading StarDist pretrained 3D model\")" + "\n"
80+
+ getLoadModelCommand()
81+
+ "image_ndarray = image.ndarray()" + "\n"
82+
+ "image_normalized = normalize(image_ndarray, 1, 99.8, axis=axes_normalize)" + "\n"
83+
+ "task.update(message=\"Image shape:\" + str(image_normalized.shape))" + "\n"
84+
+ "\n"
85+
+ "guessed_tiles = model._guess_n_tiles(image_normalized)" + "\n"
86+
+ "task.update(message=\"Guessed tiles: \" + str(guessed_tiles))" + "\n"
87+
+ "\n"
88+
+ "label_image, details = model.predict_instances(image_normalized, axes='ZYX', n_tiles=guessed_tiles)" + "\n"
89+
+ "shared = appose.NDArray(image.dtype, image.shape)" + "\n"
90+
+ "shared.ndarray()[:] = label_image" + "\n"
91+
+ "task.outputs['label_image'] = shared" + "\n";
92+
}
93+
94+
public ModelType getModelType()
95+
{
96+
return modelType;
97+
}
98+
99+
private String getImportStarDistCommand()
100+
{
101+
if ( modelType == null )
102+
{
103+
if ( dataIs2D )
104+
return "from stardist.models import StarDist2D" + "\n ";
105+
return "from stardist.models import StarDist3D" + "\n ";
106+
}
107+
if ( modelType.is2D() )
108+
return "from stardist.models import StarDist2D" + "\n ";
109+
return "from stardist.models import StarDist3D" + "\n ";
110+
}
111+
112+
private String getLoadModelCommand()
113+
{
114+
if ( modelType == null )
115+
{
116+
if ( dataIs2D )
117+
return "model = StarDist2D.from_pretrained('2D_demo')" + "\n";
118+
else
119+
return "model = StarDist3D.from_pretrained('3D_demo')" + "\n";
120+
}
121+
String starDistModel = modelType.is2D() ? "StarDist2D" : "StarDist3D";
122+
return "model = " + starDistModel + "(None, name='" + modelType.getModelPath() + "', basedir=r\"" + PATH + "\")" + "\n";
123+
}
124+
125+
public enum ModelType
126+
{
127+
PLANT_NUCLEI_3D( "StarDist Plant Nuclei 3D ResNet", "stardist-plant-nuclei-3d", false ),
128+
FLUO_2D( "StarDist Fluorescence Nuclei Segmentation", "stardist-fluo-2d", true ),
129+
H_E( "StarDist H&E Nuclei Segmentation", "stardist-h-e-nuclei", true ),
130+
DEMO_2D( "StarDist Demo", "stardist-demo-2d", true ),
131+
DEMO_3D( "StarDist Demo", "stardist-demo-3d", false );
132+
133+
private final String modelName;
134+
135+
private final String modelPath;
136+
137+
private final boolean is2D;
138+
139+
ModelType( final String modelName, final String modelPath, final boolean is2D )
140+
{
141+
this.modelName = modelName;
142+
this.modelPath = modelPath;
143+
this.is2D = is2D;
144+
}
145+
146+
public String getModelName()
147+
{
148+
return modelName;
149+
}
150+
151+
public String getModelPath()
152+
{
153+
return modelPath;
154+
}
155+
156+
public boolean is2D()
157+
{
158+
return is2D;
159+
}
160+
161+
@Override
162+
public String toString()
163+
{
164+
String dimensionality = is2D ? " (2D)" : " (3D)";
165+
return modelName + dimensionality;
166+
}
167+
168+
public static ModelType fromString( final String modelName )
169+
{
170+
for ( ModelType type : ModelType.values() )
171+
{
172+
if ( type.modelName.equalsIgnoreCase( modelName ) )
173+
{
174+
return type;
175+
}
176+
}
177+
throw new IllegalArgumentException( "No enum constant for model name: " + modelName );
178+
}
179+
}
180+
}

src/main/java/org/mastodon/mamut/detection/StarDist3D.java

Lines changed: 0 additions & 92 deletions
This file was deleted.

src/main/java/org/mastodon/mamut/detection/StarDist3DDetector.java renamed to src/main/java/org/mastodon/mamut/detection/StarDistDetector.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import static org.mastodon.tracking.detection.DetectorKeys.KEY_MAX_TIMEPOINT;
77
import static org.mastodon.tracking.detection.DetectorKeys.KEY_MIN_TIMEPOINT;
88
import static org.mastodon.tracking.detection.DetectorKeys.KEY_SETUP_ID;
9+
import static org.mastodon.tracking.mamut.trackmate.wizard.descriptors.StarDistDetectorDescriptor.KEY_MODEL_TYPE;
910
import static org.mastodon.tracking.linking.LinkingUtils.checkParameter;
1011

1112
import java.lang.invoke.MethodHandles;
@@ -39,7 +40,7 @@
3940
+ "A cell probability threshold can be set to allow more/less sensitive detection.<br><br>"
4041
+ "<strong>When this detection method is used for the first time, internet connection is needed, since an internal installation process is started. The installation consumes ~7GB hard disk space.</strong><br>"
4142
+ "</html>" )
42-
public class StarDist3DDetector extends AbstractSpotDetectorOp
43+
public class StarDistDetector extends AbstractSpotDetectorOp
4344
{
4445

4546
private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
@@ -55,7 +56,8 @@ public void compute( final List< SourceAndConverter< ? > > sources, final ModelG
5556
// A. Read the settings map, and check the validity.
5657
final StringBuilder errorHolder = new StringBuilder();
5758
boolean good;
58-
good = checkParameter( settings, KEY_SETUP_ID, Integer.class, errorHolder );
59+
good = checkParameter( settings, KEY_MODEL_TYPE, StarDist.ModelType.class, errorHolder );
60+
good = good & checkParameter( settings, KEY_SETUP_ID, Integer.class, errorHolder );
5961
good = good & checkParameter( settings, KEY_MIN_TIMEPOINT, Integer.class, errorHolder );
6062
good = good & checkParameter( settings, KEY_MAX_TIMEPOINT, Integer.class, errorHolder );
6163
if ( !good )
@@ -68,6 +70,7 @@ public void compute( final List< SourceAndConverter< ? > > sources, final ModelG
6870
final int minTimepoint = ( int ) settings.get( KEY_MIN_TIMEPOINT );
6971
final int maxTimepoint = ( int ) settings.get( KEY_MAX_TIMEPOINT );
7072
final int setup = ( int ) settings.get( KEY_SETUP_ID );
73+
final StarDist.ModelType modelType = ( StarDist.ModelType ) settings.get( KEY_MODEL_TYPE );
7174

7275
if ( setup < 0 || setup >= sources.size() )
7376
{
@@ -86,7 +89,7 @@ public void compute( final List< SourceAndConverter< ? > > sources, final ModelG
8689
// The `statusService` can be used to show short messages.
8790
statusService.showStatus( "Detecting spots using Cellpose." );
8891

89-
try (StarDist3D starDist3D = new StarDist3D( null ))
92+
try (StarDist starDist = new StarDist( modelType ))
9093
{
9194
for ( int timepoint = minTimepoint; timepoint <= maxTimepoint; timepoint++ )
9295
{
@@ -115,13 +118,14 @@ public void compute( final List< SourceAndConverter< ? > > sources, final ModelG
115118
*/
116119
final int level = 0;
117120
final RandomAccessibleInterval< ? > image = source.getSource( timepoint, level );
121+
boolean is3D = is3D( image );
118122

119123
/*
120124
* This is the 3D image of the current time-point, specified
121125
* channel. It is always 3D. If the source is 2D, the 3rd dimension
122126
* will have a size of 1.
123127
*/
124-
Img< ? > segmentation = starDist3D.segmentImage( Cast.unchecked( image ) );
128+
Img< ? > segmentation = starDist.segmentImage( Cast.unchecked( image ) );
125129

126130
final AffineTransform3D transform = DetectionUtil.getTransform( sources, timepoint, setup, level );
127131
LabelImageUtils.createSpotsForFrame( graph, Cast.unchecked( segmentation ), timepoint, transform, 1d );

0 commit comments

Comments
 (0)