Skip to content

Commit a0ec5d0

Browse files
committed
Add Apply SoftMax as Model Configuration option
1 parent dfab680 commit a0ec5d0

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

src/main/java/qupath/ext/wsinfer/WSInfer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ public static void runInference(ImageData<BufferedImage> imageData, WSInferModel
194194
}
195195

196196
boolean applySoftmax = true;
197-
Translator translator = buildTranslator(wsiModel, pipeline, applySoftmax);
197+
Translator translator = buildTranslator(wsiModel, pipeline);
198198
Criteria<Image, Classifications> criteria = buildCriteria(wsiModel, translator, device);
199199
List<String> classNames = wsiModel.getConfiguration().getClassNames();
200200
long startTime = System.currentTimeMillis();
@@ -308,13 +308,13 @@ private static boolean isMPS(Device device) {
308308
}
309309

310310

311-
private static Translator<Image, Classifications> buildTranslator(WSInferModel wsiModel, Pipeline pipeline, boolean applySoftmax) {
311+
private static Translator<Image, Classifications> buildTranslator(WSInferModel wsiModel, Pipeline pipeline) {
312312
// We should use ImageClassificationTranslator.builder() in the future if this is updated to work with MPS
313313
// (See javadocs for MpsSupport.WSInferClassificationTranslator for details)
314314
// ImageClassificationTranslator.Builder builder = ImageClassificationTranslator.builder()
315315
return MpsSupport.WSInferClassificationTranslator.builder()
316316
.optSynset(wsiModel.getConfiguration().getClassNames())
317-
.optApplySoftmax(applySoftmax)
317+
.optApplySoftmax(wsiModel.getConfiguration().isApplySoftmax())
318318
.setPipeline(pipeline)
319319
.build();
320320
}

src/main/java/qupath/ext/wsinfer/models/WSInferModelConfiguration.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ public class WSInferModelConfiguration {
4545
@SerializedName("spacing_um_px")
4646
private double spacingUmPx;
4747

48+
@SerializedName("apply_softmax")
49+
private boolean applySoftmax = true;
50+
4851
private List<WSInferTransform> transform;
4952

5053
/**
@@ -78,4 +81,12 @@ public double getPatchSizePixels() {
7881
public double getSpacingMicronPerPixel() {
7982
return spacingUmPx;
8083
}
84+
85+
/**
86+
* Whether to apply softmax to model output
87+
* @return
88+
*/
89+
public boolean isApplySoftmax() {
90+
return applySoftmax;
91+
}
8192
}

0 commit comments

Comments
 (0)