Skip to content

Commit 5975827

Browse files
author
Stefan Hahmann
committed
Add StarDist3D detector
1 parent 7e41818 commit 5975827

6 files changed

Lines changed: 683 additions & 60 deletions

File tree

pom.xml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,20 @@
150150
<version>0.4.0-SNAPSHOT</version>
151151
</dependency>
152152

153-
<!-- mastodon tracking as base for python based detection -->
153+
<!-- mastodon tracking as base for Python-based detection -->
154154
<dependency>
155155
<groupId>${mastodon.group}</groupId>
156156
<artifactId>mastodon-tracking</artifactId>
157157
<version>${mastodon-tracking.version}</version>
158158
</dependency>
159159

160+
<!-- dl-modelrunner -->
161+
<dependency>
162+
<groupId>io.bioimage</groupId>
163+
<artifactId>dl-modelrunner</artifactId>
164+
<version>0.5.10</version>
165+
</dependency>
166+
160167
<!-- Test dependencies -->
161168
<dependency>
162169
<groupId>org.junit.jupiter</groupId>
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
package org.mastodon.mamut.detection;
2+
3+
import java.io.File;
4+
import java.io.IOException;
5+
import java.lang.invoke.MethodHandles;
6+
import java.nio.file.Path;
7+
import java.nio.file.Paths;
8+
import java.util.Map;
9+
10+
import net.imglib2.util.Cast;
11+
12+
import org.slf4j.Logger;
13+
import org.slf4j.LoggerFactory;
14+
15+
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
16+
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
17+
import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
18+
import io.bioimage.modelrunner.utils.JSONUtils;
19+
20+
public class StarDist extends Segmentation3D
21+
{
22+
private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
23+
24+
private final ModelType modelType;
25+
26+
private String installationFolderName;
27+
28+
private boolean dataIs2D;
29+
30+
public StarDist( final ModelType model ) throws IOException, InterruptedException
31+
{
32+
super();
33+
logger.info( "Initializing StarDist, model: {}", model );
34+
this.modelType = model;
35+
Path starDistModelRoot = getStarDistModelRoot();
36+
if ( starDistModelRoot == null )
37+
logger.debug( "StarDist model path is null. This is normal for the built-in demo models" );
38+
else
39+
{
40+
File directory = starDistModelRoot.toFile();
41+
if ( !directory.exists() )
42+
{
43+
if ( !directory.mkdirs() )
44+
{
45+
throw new RuntimeException( "Failed to create environment directory: " + directory );
46+
}
47+
}
48+
if ( directory.isDirectory() )
49+
{
50+
File[] files = directory.listFiles();
51+
if ( files != null && files.length > 0 )
52+
{
53+
logger.info( "Found {} files/directories in {}", files.length, directory.getAbsolutePath() );
54+
for ( File file : files )
55+
logger.debug( "File/Directory: {}", file.getName() );
56+
installationFolderName = Paths.get( files[ 0 ].getAbsolutePath() ).getFileName().toString();
57+
logger.info( "Reusing model in folder: {}", installationFolderName );
58+
}
59+
else
60+
{
61+
try
62+
{
63+
logger.info( "Downloading model to {}", directory.getAbsolutePath() );
64+
BioimageioRepo repo = BioimageioRepo.connect();
65+
ModelDescriptor descriptor = repo.selectByName( modelType.getModelName() );
66+
String installationFolder = repo.downloadByName( modelType.getModelName(), directory.getAbsolutePath() );
67+
installationFolderName = Paths.get( installationFolder ).getFileName().toString();
68+
createConfigFromBioimageio( descriptor, directory.getAbsolutePath() + File.separator + installationFolderName );
69+
logger.info( "Downloading finished. Installation folder: {}", installationFolderName );
70+
}
71+
catch ( IllegalArgumentException e )
72+
{
73+
logger.info( "Exception while downloading model: {}", e.getMessage() );
74+
}
75+
catch ( ModelSpecsException e )
76+
{
77+
logger.info( "Exception while creating config.json: {}", e.getMessage() );
78+
}
79+
}
80+
}
81+
else
82+
logger.error( "The specified path is not a directory: {}", directory.getAbsolutePath() );
83+
}
84+
}
85+
86+
private Path getStarDistModelRoot()
87+
{
88+
if ( modelType.getModelPath() == null )
89+
return null;
90+
return Paths.get( System.getProperty( "user.home" ), ".local", "share", "appose", "stardist", "models", modelType.getModelPath() );
91+
}
92+
93+
private static void createConfigFromBioimageio( final ModelDescriptor descriptor, final String modelDir )
94+
throws IOException, ModelSpecsException
95+
{
96+
Map< String, Object > stardistMap = Cast.unchecked( descriptor.getConfig().getSpecMap().get( "stardist" ) );
97+
Map< String, Object > stardistConfig = Cast.unchecked( stardistMap.get( "config" ) );
98+
File jsonFile = new File( modelDir, "config.json" );
99+
logger.info( "Creating config.json file: {}", jsonFile.getAbsolutePath() );
100+
JSONUtils.writeJSONFile( jsonFile.getAbsolutePath(), stardistConfig );
101+
}
102+
103+
@Override
104+
String generateEnvFileContent()
105+
{
106+
return "name: stardist\n"
107+
+ "channels:\n"
108+
+ " - conda-forge\n"
109+
+ "dependencies:\n"
110+
+ " - python=3.10\n"
111+
+ " - cudatoolkit=11.2\n"
112+
+ " - cudnn=8.1.0\n"
113+
+ " - numpy<1.24\n"
114+
+ " - pip\n"
115+
+ " - pip:\n"
116+
+ " - numpy<1.24\n"
117+
+ " - tensorflow==2.10\n"
118+
+ " - stardist==0.8.5\n"
119+
+ " - appose\n";
120+
}
121+
122+
@Override
123+
String generateScript()
124+
{
125+
return "import numpy as np" + "\n"
126+
+ "import appose" + "\n"
127+
+ "from csbdeep.utils import normalize" + "\n"
128+
+ getImportStarDistCommand()
129+
+ "\n"
130+
+ "task.update(message=\"Imports completed\")" + "\n"
131+
+ "np.random.seed(6)" + "\n"
132+
+ "axes_normalize = (0, 1, 2)" + "\n"
133+
+ "\n"
134+
+ "task.update(message=\"Loading StarDist pretrained 3D model\")" + "\n"
135+
+ getLoadModelCommand()
136+
+ "image_ndarray = image.ndarray()" + "\n"
137+
+ "image_normalized = normalize(image_ndarray, 1, 99.8, axis=axes_normalize)" + "\n"
138+
+ "task.update(message=\"Image shape:\" + str(image_normalized.shape))" + "\n"
139+
+ "\n"
140+
+ "guessed_tiles = model._guess_n_tiles(image_normalized)" + "\n"
141+
+ "task.update(message=\"Guessed tiles: \" + str(guessed_tiles))" + "\n"
142+
+ "\n"
143+
+ "label_image, details = model.predict_instances(image_normalized, axes='ZYX', n_tiles=guessed_tiles)" + "\n"
144+
+ "shared = appose.NDArray(image.dtype, image.shape)" + "\n"
145+
+ "shared.ndarray()[:] = label_image" + "\n"
146+
+ "task.outputs['label_image'] = shared" + "\n";
147+
}
148+
149+
public ModelType getModelType()
150+
{
151+
return modelType;
152+
}
153+
154+
private String getImportStarDistCommand()
155+
{
156+
if ( modelType.getModelPath() == null )
157+
{
158+
if ( dataIs2D )
159+
return "from stardist.models import StarDist2D" + "\n ";
160+
return "from stardist.models import StarDist3D" + "\n ";
161+
}
162+
if ( modelType.is2D() )
163+
return "from stardist.models import StarDist2D" + "\n ";
164+
return "from stardist.models import StarDist3D" + "\n ";
165+
}
166+
167+
private String getLoadModelCommand()
168+
{
169+
if ( modelType.getModelPath() == null )
170+
{
171+
if ( dataIs2D )
172+
return "model = StarDist2D.from_pretrained('2D_demo')" + "\n";
173+
else
174+
return "model = StarDist3D.from_pretrained('3D_demo')" + "\n";
175+
}
176+
String starDistModel = modelType.is2D() ? "StarDist2D" : "StarDist3D";
177+
return "model = " + starDistModel + "(None, name='" + installationFolderName + "', basedir=r\"models" + File.separator
178+
+ modelType.getModelPath() + "\")"
179+
+ "\n";
180+
}
181+
182+
public enum ModelType
183+
{
184+
PLANT_NUCLEI_3D( "StarDist Plant Nuclei 3D ResNet", "stardist-plant-nuclei-3d", false ),
185+
FLUO_2D( "StarDist Fluorescence Nuclei Segmentation", "stardist-fluo-2d", true ),
186+
H_E( "StarDist H&E Nuclei Segmentation", "stardist-h-e-nuclei", true ),
187+
DEMO( "StarDist Demo", null, null );
188+
189+
private final String modelName;
190+
191+
private final String modelPath;
192+
193+
private final Boolean is2D;
194+
195+
ModelType( final String modelName, final String modelPath, final Boolean is2D )
196+
{
197+
this.modelName = modelName;
198+
this.modelPath = modelPath;
199+
this.is2D = is2D;
200+
}
201+
202+
public String getModelName()
203+
{
204+
return modelName;
205+
}
206+
207+
public String getModelPath()
208+
{
209+
return modelPath;
210+
}
211+
212+
public Boolean is2D()
213+
{
214+
return is2D;
215+
}
216+
217+
public String getDisplayName()
218+
{
219+
String dimensionality;
220+
if ( is2D == null )
221+
dimensionality = " (2D/3D)";
222+
else if ( is2D )
223+
dimensionality = " (2D)";
224+
else
225+
dimensionality = " (3D)";
226+
227+
return modelName + dimensionality;
228+
}
229+
230+
public static ModelType fromString( final String modelName )
231+
{
232+
for ( ModelType type : ModelType.values() )
233+
{
234+
if ( type.modelName.equalsIgnoreCase( modelName ) )
235+
{
236+
return type;
237+
}
238+
}
239+
throw new IllegalArgumentException( "No enum constant for model name: " + modelName );
240+
}
241+
}
242+
}

src/main/java/org/mastodon/mamut/detection/StarDist3D.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

0 commit comments

Comments
 (0)