diff --git a/src/main/java/qupath/ext/wsinfer/WSInfer.java b/src/main/java/qupath/ext/wsinfer/WSInfer.java index 730e248..ed51b4e 100644 --- a/src/main/java/qupath/ext/wsinfer/WSInfer.java +++ b/src/main/java/qupath/ext/wsinfer/WSInfer.java @@ -193,8 +193,7 @@ public static void runInference(ImageData imageData, WSInferModel } } - boolean applySoftmax = true; - Translator translator = buildTranslator(wsiModel, pipeline, applySoftmax); + Translator translator = buildTranslator(wsiModel, pipeline); Criteria criteria = buildCriteria(wsiModel, translator, device); List classNames = wsiModel.getConfiguration().getClassNames(); long startTime = System.currentTimeMillis(); @@ -308,13 +307,13 @@ private static boolean isMPS(Device device) { } - private static Translator buildTranslator(WSInferModel wsiModel, Pipeline pipeline, boolean applySoftmax) { + private static Translator buildTranslator(WSInferModel wsiModel, Pipeline pipeline) { // We should use ImageClassificationTranslator.builder() in the future if this is updated to work with MPS // (See javadocs for MpsSupport.WSInferClassificationTranslator for details) // ImageClassificationTranslator.Builder builder = ImageClassificationTranslator.builder() return MpsSupport.WSInferClassificationTranslator.builder() .optSynset(wsiModel.getConfiguration().getClassNames()) - .optApplySoftmax(applySoftmax) + .optApplySoftmax(wsiModel.getConfiguration().isApplySoftmax()) .setPipeline(pipeline) .build(); } diff --git a/src/main/java/qupath/ext/wsinfer/models/WSInferModelConfiguration.java b/src/main/java/qupath/ext/wsinfer/models/WSInferModelConfiguration.java index 2bdc56b..5655555 100644 --- a/src/main/java/qupath/ext/wsinfer/models/WSInferModelConfiguration.java +++ b/src/main/java/qupath/ext/wsinfer/models/WSInferModelConfiguration.java @@ -45,6 +45,9 @@ public class WSInferModelConfiguration { @SerializedName("spacing_um_px") private double spacingUmPx; + @SerializedName("apply_softmax") + private boolean applySoftmax = true; + private List transform; /** @@ -78,4 +81,12 @@ public double getPatchSizePixels() { public double getSpacingMicronPerPixel() { return spacingUmPx; } + + /** + * Whether to apply softmax to model output + * @return + */ + public boolean isApplySoftmax() { + return applySoftmax; + } }