Skip to content

Commit 12859be

Browse files
author
Stefan Hahmann
committed
Add StarDist3D detector
1 parent 7e41818 commit 12859be

6 files changed

Lines changed: 610 additions & 60 deletions

File tree

pom.xml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,20 @@
150150
<version>0.4.0-SNAPSHOT</version>
151151
</dependency>
152152

153-
<!-- mastodon tracking as base for python based detection -->
153+
<!-- mastodon tracking as base for Python-based detection -->
154154
<dependency>
155155
<groupId>${mastodon.group}</groupId>
156156
<artifactId>mastodon-tracking</artifactId>
157157
<version>${mastodon-tracking.version}</version>
158158
</dependency>
159159

160+
<!-- dl-modelrunner -->
161+
<dependency>
162+
<groupId>io.bioimage</groupId>
163+
<artifactId>dl-modelrunner</artifactId>
164+
<version>0.5.10</version>
165+
</dependency>
166+
160167
<!-- Test dependencies -->
161168
<dependency>
162169
<groupId>org.junit.jupiter</groupId>
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
package org.mastodon.mamut.detection;
2+
3+
import java.io.File;
4+
import java.io.IOException;
5+
import java.lang.invoke.MethodHandles;
6+
import java.nio.file.Path;
7+
import java.nio.file.Paths;
8+
9+
import org.slf4j.Logger;
10+
import org.slf4j.LoggerFactory;
11+
12+
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
13+
14+
public class StarDist extends Segmentation3D
15+
{
16+
private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
17+
18+
private final ModelType modelType;
19+
20+
private boolean dataIs2D;
21+
22+
private static final String PATH = "stardist";
23+
24+
public StarDist( final ModelType model ) throws IOException, InterruptedException
25+
{
26+
super();
27+
this.modelType = model;
28+
if ( modelType == null )
29+
logger.info( "No star dist model path specified. Using pretrained demo model." );
30+
else
31+
{
32+
Path starDistModelRoot = getStarDistModelRoot();
33+
File directory = starDistModelRoot.toFile();
34+
if ( !directory.exists() )
35+
{
36+
if ( !directory.mkdirs() )
37+
{
38+
throw new RuntimeException( "Failed to create environment directory: " + directory );
39+
}
40+
}
41+
if ( directory.isDirectory() )
42+
{
43+
String[] files = directory.list();
44+
if ( files != null && files.length > 0 )
45+
logger.info( "Found {} files/directories in {}", files.length, directory.getAbsolutePath() );
46+
else
47+
{
48+
try
49+
{
50+
logger.info( "Downloading model to {}", directory.getAbsolutePath() );
51+
BioimageioRepo.connect().downloadByName( modelType.getModelName(), directory.getAbsolutePath() );
52+
}
53+
catch ( IllegalArgumentException e )
54+
{
55+
logger.info( "Exception while downloading model: {}", e.getMessage() );
56+
}
57+
}
58+
}
59+
else
60+
logger.error( "The specified path is not a directory: {}", directory.getAbsolutePath() );
61+
}
62+
}
63+
64+
private Path getStarDistModelRoot()
65+
{
66+
return Paths.get( System.getProperty( "user.home" ), ".local", "share", "stardist", "models", modelType.getModelPath() );
67+
}
68+
69+
@Override
70+
String generateEnvFileContent()
71+
{
72+
return "name: stardist\n"
73+
+ "channels:\n"
74+
+ " - conda-forge\n"
75+
+ "dependencies:\n"
76+
+ " - python=3.10\n"
77+
+ " - cudatoolkit=11.2\n"
78+
+ " - cudnn=8.1.0\n"
79+
+ " - numpy<1.24\n"
80+
+ " - pip\n"
81+
+ " - pip:\n"
82+
+ " - numpy<1.24\n"
83+
+ " - tensorflow==2.10\n"
84+
+ " - stardist==0.8.5\n"
85+
+ " - appose\n";
86+
}
87+
88+
@Override
89+
String generateScript()
90+
{
91+
return "import numpy as np" + "\n"
92+
+ "import appose" + "\n"
93+
+ "from csbdeep.utils import normalize" + "\n"
94+
+ getImportStarDistCommand()
95+
+ "\n"
96+
+ "task.update(message=\"Imports completed\")" + "\n"
97+
+ "np.random.seed(6)" + "\n"
98+
+ "axes_normalize = (0, 1, 2)" + "\n"
99+
+ "\n"
100+
+ "task.update(message=\"Loading StarDist pretrained 3D model\")" + "\n"
101+
+ getLoadModelCommand()
102+
+ "image_ndarray = image.ndarray()" + "\n"
103+
+ "image_normalized = normalize(image_ndarray, 1, 99.8, axis=axes_normalize)" + "\n"
104+
+ "task.update(message=\"Image shape:\" + str(image_normalized.shape))" + "\n"
105+
+ "\n"
106+
+ "guessed_tiles = model._guess_n_tiles(image_normalized)" + "\n"
107+
+ "task.update(message=\"Guessed tiles: \" + str(guessed_tiles))" + "\n"
108+
+ "\n"
109+
+ "label_image, details = model.predict_instances(image_normalized, axes='ZYX', n_tiles=guessed_tiles)" + "\n"
110+
+ "shared = appose.NDArray(image.dtype, image.shape)" + "\n"
111+
+ "shared.ndarray()[:] = label_image" + "\n"
112+
+ "task.outputs['label_image'] = shared" + "\n";
113+
}
114+
115+
public ModelType getModelType()
116+
{
117+
return modelType;
118+
}
119+
120+
private String getImportStarDistCommand()
121+
{
122+
if ( modelType == null )
123+
{
124+
if ( dataIs2D )
125+
return "from stardist.models import StarDist2D" + "\n ";
126+
return "from stardist.models import StarDist3D" + "\n ";
127+
}
128+
if ( modelType.is2D() )
129+
return "from stardist.models import StarDist2D" + "\n ";
130+
return "from stardist.models import StarDist3D" + "\n ";
131+
}
132+
133+
private String getLoadModelCommand()
134+
{
135+
if ( modelType == null )
136+
{
137+
if ( dataIs2D )
138+
return "model = StarDist2D.from_pretrained('2D_demo')" + "\n";
139+
else
140+
return "model = StarDist3D.from_pretrained('3D_demo')" + "\n";
141+
}
142+
String starDistModel = modelType.is2D() ? "StarDist2D" : "StarDist3D";
143+
return "model = " + starDistModel + "(None, name='" + modelType.getModelName() + "', basedir=r\"" + getStarDistModelRoot() + "\")"
144+
+ "\n";
145+
}
146+
147+
public enum ModelType
148+
{
149+
PLANT_NUCLEI_3D( "StarDist Plant Nuclei 3D ResNet", "stardist-plant-nuclei-3d", false ),
150+
FLUO_2D( "StarDist Fluorescence Nuclei Segmentation", "stardist-fluo-2d", true ),
151+
H_E( "StarDist H&E Nuclei Segmentation", "stardist-h-e-nuclei", true ),
152+
DEMO_2D( "StarDist Demo", "stardist-demo-2d", true ),
153+
DEMO_3D( "StarDist Demo", "stardist-demo-3d", false );
154+
155+
private final String modelName;
156+
157+
private final String modelPath;
158+
159+
private final boolean is2D;
160+
161+
ModelType( final String modelName, final String modelPath, final boolean is2D )
162+
{
163+
this.modelName = modelName;
164+
this.modelPath = modelPath;
165+
this.is2D = is2D;
166+
}
167+
168+
public String getModelName()
169+
{
170+
return modelName;
171+
}
172+
173+
public String getModelPath()
174+
{
175+
return modelPath;
176+
}
177+
178+
public boolean is2D()
179+
{
180+
return is2D;
181+
}
182+
183+
@Override
184+
public String toString()
185+
{
186+
String dimensionality = is2D ? " (2D)" : " (3D)";
187+
return modelName + dimensionality;
188+
}
189+
190+
public static ModelType fromString( final String modelName )
191+
{
192+
for ( ModelType type : ModelType.values() )
193+
{
194+
if ( type.modelName.equalsIgnoreCase( modelName ) )
195+
{
196+
return type;
197+
}
198+
}
199+
throw new IllegalArgumentException( "No enum constant for model name: " + modelName );
200+
}
201+
}
202+
}

src/main/java/org/mastodon/mamut/detection/StarDist3D.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

0 commit comments

Comments
 (0)