Skip to content

Commit a3d4295

Browse files
Merge pull request #142 from mastodon-sc/trackastra
Add Trackastra Linker
2 parents 6d5a936 + 2a5ee17 commit a3d4295

30 files changed

Lines changed: 1993 additions & 1 deletion

src/main/java/org/mastodon/mamut/feature/AbstractSerialFeatureComputer.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
import java.util.Collection;
3737

3838
/**
39-
* Abstract class for computing features in a serial way.
39+
* Abstract class for computing features in a serial way, i.e., one vertex after
40+
* the other.
4041
* @param <V> the type of vertex.
4142
*/
4243
public abstract class AbstractSerialFeatureComputer< V extends Vertex< ? > > extends AbstractResettableFeatureComputer

src/main/java/org/mastodon/mamut/io/exporter/labelimage/ExportLabelImageController.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,14 @@
3939
import net.imagej.axis.Axes;
4040
import net.imagej.axis.CalibratedAxis;
4141
import net.imagej.axis.DefaultLinearAxis;
42+
import net.imglib2.RandomAccess;
43+
import net.imglib2.RandomAccessibleInterval;
4244
import net.imglib2.cache.img.DiskCachedCellImg;
4345
import net.imglib2.cache.img.DiskCachedCellImgFactory;
4446
import net.imglib2.cache.img.DiskCachedCellImgOptions;
4547
import net.imglib2.img.Img;
4648
import net.imglib2.img.display.imagej.ImageJFunctions;
49+
import net.imglib2.loops.LoopBuilder;
4750
import net.imglib2.realtransform.AffineTransform3D;
4851
import net.imglib2.type.numeric.RealType;
4952
import net.imglib2.type.numeric.real.FloatType;
@@ -73,6 +76,7 @@
7376
import java.util.Arrays;
7477
import java.util.List;
7578
import java.util.concurrent.locks.ReentrantReadWriteLock;
79+
import java.util.function.BiConsumer;
7680
import java.util.function.Consumer;
7781

7882
public class ExportLabelImageController
@@ -158,6 +162,9 @@ public void saveLabelImageToFile(
158162
int targetFrameId = sourceFrameId / frameRateReduction;
159163
logger.trace( "sourceFrameId: {}, targetFrameId: {}", sourceFrameId, targetFrameId );
160164
IntervalView< FloatType > frame = Views.hyperSlice( img, 3, targetFrameId );
165+
166+
// RandomAccessibleInterval< RealType< ? > > hdf5 = source.getSource( targetFrameId, mipMapLevel ); // use to export hdf5 source to tiffs
167+
// LoopBuilder.setImages( hdf5, frame ).forEachPixel( ( realType, floatType ) -> floatType.set( realType.getRealFloat() ) );
161168
AbstractSource< FloatType > frameSource =
162169
new RandomAccessibleIntervalSource<>( frame, new FloatType(), transform, "Ellipsoids" );
163170
final EllipsoidIterable< FloatType > ellipsoidIterable = new EllipsoidIterable<>( frameSource );
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package org.mastodon.mamut.io.exporter.labelimage;
2+
3+
import java.lang.invoke.MethodHandles;
4+
5+
import net.imglib2.RandomAccessibleInterval;
6+
import net.imglib2.img.array.ArrayImgs;
7+
import net.imglib2.realtransform.AffineTransform3D;
8+
import net.imglib2.type.numeric.integer.IntType;
9+
10+
import org.mastodon.mamut.feature.EllipsoidIterable;
11+
import org.mastodon.mamut.model.Spot;
12+
import org.mastodon.mamut.util.ImgUtils;
13+
import org.mastodon.spatial.SpatioTemporalIndex;
14+
import org.slf4j.Logger;
15+
import org.slf4j.LoggerFactory;
16+
17+
import bdv.util.AbstractSource;
18+
import bdv.util.RandomAccessibleIntervalSource;
19+
20+
public class ExportLabelImageUtils
21+
{
22+
private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
23+
24+
private ExportLabelImageUtils()
25+
{
26+
// prevent instantiation
27+
}
28+
29+
public static RandomAccessibleInterval< IntType > getLabelImageFromSpots( final AffineTransform3D transform, long[] dimensions,
30+
final int level, final int timepoint, final SpatioTemporalIndex< Spot > spatioTemporalIndex )
31+
{
32+
RandomAccessibleInterval< IntType > masksImage = ArrayImgs.ints( dimensions );
33+
AbstractSource< IntType > masksSource = new RandomAccessibleIntervalSource<>( masksImage, new IntType(), transform, "masks" );
34+
final EllipsoidIterable< IntType > ellipsoidIterable = new EllipsoidIterable<>( masksSource );
35+
int spotCount = 0;
36+
for ( Spot spot : spatioTemporalIndex.getSpatialIndex( timepoint ) )
37+
{
38+
ellipsoidIterable.reset( spot, level );
39+
ellipsoidIterable.forEach( pixel -> pixel.set( spot.getInternalPoolIndex() + 1 ) );
40+
spotCount++;
41+
}
42+
String masksDimensions = ImgUtils.getImageDimensionsAsString( masksImage );
43+
logger.info( "Wrote {} spot(s) into image with dimensions: {} and type: {} ", spotCount, masksDimensions,
44+
masksSource.getType().getClass().getSimpleName() );
45+
return masksImage;
46+
}
47+
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package org.mastodon.mamut.linking.trackastra;
2+
3+
import static org.mastodon.mamut.detection.DeepLearningDetectorKeys.KEY_LEVEL;
4+
import static org.mastodon.mamut.linking.trackastra.TrackastraUtils.KEY_SOURCE;
5+
import static org.mastodon.mamut.linking.trackastra.TrackastraUtils.KEY_WINDOW_SIZE;
6+
import static org.mastodon.tracking.detection.DetectorKeys.KEY_MAX_TIMEPOINT;
7+
import static org.mastodon.tracking.detection.DetectorKeys.KEY_MIN_TIMEPOINT;
8+
9+
import java.lang.invoke.MethodHandles;
10+
import java.util.List;
11+
12+
import net.imglib2.RealLocalizable;
13+
import net.imglib2.algorithm.Benchmark;
14+
import net.imglib2.util.Cast;
15+
16+
import org.apache.commons.lang.StringUtils;
17+
import org.mastodon.Ref;
18+
import org.mastodon.graph.Edge;
19+
import org.mastodon.graph.ReadOnlyGraph;
20+
import org.mastodon.graph.Vertex;
21+
import org.mastodon.mamut.linking.trackastra.appose.computation.LinkPrediction;
22+
import org.mastodon.mamut.linking.trackastra.appose.types.RegionProps;
23+
import org.mastodon.mamut.linking.trackastra.appose.computation.RegionPropsComputation;
24+
import org.mastodon.mamut.linking.trackastra.appose.types.SingleTimepointRegionProps;
25+
import org.mastodon.spatial.HasTimepoint;
26+
import org.mastodon.spatial.SpatioTemporalIndex;
27+
import org.mastodon.tracking.linking.graph.AbstractGraphParticleLinkerOp;
28+
import org.mastodon.tracking.linking.graph.GraphParticleLinkerOp;
29+
import org.scijava.plugin.Plugin;
30+
import org.slf4j.Logger;
31+
import org.slf4j.LoggerFactory;
32+
33+
import bdv.viewer.Source;
34+
35+
@Plugin( type = GraphParticleLinkerOp.class )
36+
public class TrackastraLinker< V extends Vertex< E > & HasTimepoint & RealLocalizable & Ref< V >, E extends Edge< V > >
37+
extends AbstractGraphParticleLinkerOp< V, E >
38+
implements Benchmark
39+
{
40+
41+
private static final Logger log = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
42+
43+
private long processingTime;
44+
45+
@Override
46+
public void mutate1( final ReadOnlyGraph< V, E > graph, final SpatioTemporalIndex< V > index )
47+
{
48+
long start = System.currentTimeMillis();
49+
try
50+
{
51+
List< SingleTimepointRegionProps > regionProps = computeRegionProps( index );
52+
runLinkPrediction( index, regionProps );
53+
ok = true;
54+
}
55+
catch ( TrackastraLinkingException e )
56+
{
57+
Throwable cause = e.getCause();
58+
String msg = "";
59+
if ( cause != null )
60+
msg = cause.getMessage();
61+
62+
log.error( "Error during Trackastra Linking: {}. Cause: {}.", StringUtils.defaultString( e.getMessage(), e.toString() ), msg );
63+
ok = false;
64+
errorMessage = e.getMessage() + ( ( msg != null && !msg.isEmpty() ) ? " Caused by: " + msg : "" );
65+
}
66+
finally
67+
{
68+
processingTime = System.currentTimeMillis() - start;
69+
}
70+
}
71+
72+
private List< SingleTimepointRegionProps > computeRegionProps( final SpatioTemporalIndex< V > index ) throws TrackastraLinkingException
73+
{
74+
log.info( "Computing region props for Trackastra" );
75+
String model = ( ( TrackastraModel ) settings.get( TrackastraUtils.KEY_MODEL ) ).getName();
76+
int windowSize = ( Integer ) settings.get( KEY_WINDOW_SIZE );
77+
int minTimepoint = ( int ) settings.get( KEY_MIN_TIMEPOINT );
78+
int maxTimepoint = ( int ) settings.get( KEY_MAX_TIMEPOINT );
79+
int level = ( int ) settings.get( KEY_LEVEL );
80+
Source< ? > source = ( Source< ? > ) settings.get( KEY_SOURCE );
81+
82+
log.info( "Source: {}", source );
83+
84+
int timeRange = maxTimepoint - minTimepoint + 1;
85+
if ( windowSize > timeRange )
86+
throw new IllegalArgumentException(
87+
String.format( "Window size (%d) exceeds time range (%d). Adjust window size or time range.", windowSize, timeRange ) );
88+
89+
try (RegionPropsComputation computation = new RegionPropsComputation( logger, model ))
90+
{
91+
return computation.computeRegionPropsForSource( source, level, Cast.unchecked( index ), minTimepoint, maxTimepoint );
92+
}
93+
catch ( Exception e )
94+
{
95+
throw new TrackastraLinkingException( "Failed to compute region props", e );
96+
}
97+
}
98+
99+
private void runLinkPrediction( final SpatioTemporalIndex< V > index, final List< SingleTimepointRegionProps > regionProps )
100+
throws TrackastraLinkingException
101+
{
102+
log.info( "Performing Trackastra link prediction" );
103+
try (RegionProps props = new RegionProps( regionProps );
104+
LinkPrediction prediction =
105+
new LinkPrediction( settings, Cast.unchecked( index ), Cast.unchecked( edgeCreator ), props, logger ))
106+
{
107+
prediction.predictAndCreateLinks();
108+
}
109+
catch ( Exception e )
110+
{
111+
throw new TrackastraLinkingException( "Failed to perform link prediction", e );
112+
}
113+
}
114+
115+
@Override
116+
public long getProcessingTime()
117+
{
118+
return processingTime;
119+
}
120+
121+
@Override
122+
public boolean isSuccessful()
123+
{
124+
return ok;
125+
}
126+
127+
@Override
128+
public String getErrorMessage()
129+
{
130+
return errorMessage;
131+
}
132+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package org.mastodon.mamut.linking.trackastra;
2+
3+
import java.util.Map;
4+
5+
import org.mastodon.mamut.model.ModelGraph;
6+
import org.mastodon.mamut.model.Spot;
7+
import org.mastodon.spatial.SpatioTemporalIndex;
8+
import org.mastodon.tracking.mamut.linking.AbstractSpotLinkerOp;
9+
import org.mastodon.tracking.mamut.linking.SpotLinkerOp;
10+
import org.scijava.Priority;
11+
import org.scijava.plugin.Plugin;
12+
13+
@Plugin( type = SpotLinkerOp.class, priority = Priority.LOW, name = "Trackastra Linker", description = "<html>"
14+
+ "This linker uses Trackastra for linking. Trackastra has been published in:<br>"
15+
+ "<i>TRACKASTRA: Transformer-Based Cell Tracking for Live-Cell Microscopy </i> - "
16+
+ "<i>Gallusser, B. & Weigert, M.</i>, Computer Vision – ECCV 2024.<br><br>"
17+
+ "Trackastra uses a transformer based architecture to directly learn pairwise associations of cells within a temporal window from annotated data.<br>"
18+
+ "Trackastra can account for dividing objects such as cells and allows for accurate tracking even with simple greedy linking.<br>"
19+
+ "The architecture operates solely on the full spatio-temporal context of detections.<br><br>"
20+
+ "<strong>When this linker method is used for the first time, internet connection is needed, since an internal installation process is started. The installation consumes ~2.5GB hard disk space.</strong><br>"
21+
+ "</html>" )
22+
public class TrackastraLinkerMamut extends AbstractSpotLinkerOp
23+
{
24+
@Override
25+
public void mutate1( final ModelGraph graph, final SpatioTemporalIndex< Spot > spots )
26+
{
27+
exec( graph, spots, TrackastraLinker.class );
28+
}
29+
30+
@Override
31+
public Map< String, Object > getDefaultSettings()
32+
{
33+
return TrackastraUtils.getDefaultTrackAstraSettingsMap();
34+
}
35+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.mastodon.mamut.linking.trackastra;
2+
3+
public class TrackastraLinkingException extends Exception
4+
{
5+
public TrackastraLinkingException( String message )
6+
{
7+
super( message );
8+
}
9+
10+
public TrackastraLinkingException( String message, Throwable cause )
11+
{
12+
super( message, cause );
13+
}
14+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.mastodon.mamut.linking.trackastra;
2+
3+
public enum TrackastraMode
4+
{
5+
GREEDY( "greedy", "Greedy linking with divisions" ),
6+
GREEDY_NODIV( "gredy_nodiv", "Greedy linking without divisions" );
7+
8+
private final String name;
9+
10+
private final String description;
11+
12+
TrackastraMode( final String name, final String description )
13+
{
14+
this.name = name;
15+
this.description = description;
16+
}
17+
18+
@Override
19+
public String toString()
20+
{
21+
return description;
22+
}
23+
24+
public String getName()
25+
{
26+
return name;
27+
}
28+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package org.mastodon.mamut.linking.trackastra;
2+
3+
public enum TrackastraModel
4+
{
5+
CTC( true, "ctc", "Cell Tracking Challenge (2D+3D)" ),
6+
GENERAL_2D( false, "general_2d", "General Model (2D)" );
7+
8+
private final boolean is3D;
9+
10+
private final String name;
11+
12+
private final String displayName;
13+
14+
TrackastraModel( final boolean is3D, final String name, final String displayName )
15+
{
16+
this.is3D = is3D;
17+
this.name = name;
18+
this.displayName = displayName;
19+
}
20+
21+
public boolean is3D()
22+
{
23+
return is3D;
24+
}
25+
26+
public String getName()
27+
{
28+
return name;
29+
}
30+
31+
@Override
32+
public String toString()
33+
{
34+
return displayName;
35+
}
36+
37+
}

0 commit comments

Comments
 (0)