Skip to content

Commit bea66ce

Browse files
authored
Merge pull request #6 from petebankhead/dev
Updates for v0.2.0
2 parents a4758a9 + fa08b73 commit bea66ce

File tree

4 files changed

+186
-19
lines changed

4 files changed

+186
-19
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
## Version 0.2.0
2+
3+
* Improve conversion of `NDArray` to more data types
4+
* Add `DjlTools.getXXX()` methods to get ints, floats, doubles, longs and booleans
5+
* Estimate output size in `DjlDnnModel` if shape doesn't match NDLayout
6+
* This relaxes the assumption that the output layout should match the input
7+
* New `DjlTools.get/setOverrideDevice()` methods to override DJL's default device selection
8+
* Primarily intended to explore `Device.fromName('mps')` on Apple Silicon (which sometimes works, sometimes doesn't...)
9+
10+
11+
## Version 0.1.0
12+
13+
* Initial release

build.gradle

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,35 @@ plugins {
33
id 'maven-publish'
44
// To manage included native libraries
55
alias(libs.plugins.javacpp)
6+
// Need to use this instead (without version) if running as a QuPath subproject
7+
// id 'org.bytedeco.gradle-javacpp-platform'
68
}
79

810
ext.moduleName = 'qupath.extension.djl'
911
ext.qupathVersion = gradle.ext.qupathVersion
1012

1113
description = 'QuPath extension to use Deep Java Library'
12-
version = "0.1.0"
14+
version = "0.2.0"
1315

1416
def djlVersion = libs.versions.deepJavaLibrary.get()
1517

18+
repositories {
19+
// Use this only for local development!
20+
// mavenLocal()
21+
22+
mavenCentral()
23+
24+
maven {
25+
url "https://maven.scijava.org/content/repositories/releases"
26+
}
27+
28+
maven {
29+
url "https://maven.scijava.org/content/repositories/snapshots"
30+
}
31+
32+
}
33+
34+
1635
dependencies {
1736
implementation "io.github.qupath:qupath-gui-fx:${qupathVersion}"
1837

src/main/java/qupath/ext/djl/DjlDnnModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ public NDList toBlob(Mat... mats) {
207207
@Override
208208
public List<Mat> fromBlob(NDList blob) {
209209
String layout;
210-
if (ndLayout == null && !blob.isEmpty())
210+
if ((ndLayout == null || ndLayout.length() != blob.singletonOrThrow().getShape().dimension()) && !blob.isEmpty())
211211
layout = estimateOutputLayout(blob.get(0));
212212
else
213213
layout = ndLayout;

src/main/java/qupath/ext/djl/DjlTools.java

Lines changed: 152 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222
import java.util.Arrays;
2323
import java.util.Collection;
2424
import java.util.Collections;
25+
import java.util.HashMap;
2526
import java.util.HashSet;
2627
import java.util.List;
2728
import java.util.Map;
2829
import java.util.Set;
2930

31+
import ai.djl.Device;
3032
import org.bytedeco.javacpp.Loader;
3133
import org.bytedeco.javacpp.PointerScope;
3234
import org.bytedeco.javacpp.indexer.BooleanIndexer;
@@ -131,6 +133,12 @@ public class DjlTools {
131133
* having to try to instantiate a new one.
132134
*/
133135
public static Set<String> loadedEngines = new HashSet<>();
136+
137+
/**
138+
* Default devices for each engine.
139+
* This can be used to override the default used by DJL.
140+
*/
141+
private static Map<String, Device> defaultDevices = new HashMap<>();
134142

135143
static Set<String> ALL_ENGINES = Set.of(
136144
ENGINE_DLR, ENGINE_LIGHTGBM, ENGINE_MXNET, ENGINE_ONNX_RUNTIME, ENGINE_PADDLEPADDLE,
@@ -308,29 +316,66 @@ private static ZooModel<NDList, NDList> loadModel(String engineName, String urls
308316
.optModelUrls(urls)
309317
.optProgress(new ProgressBar());
310318

311-
boolean foundEngine = false;
319+
String selectedEngine = null;
312320
if (engineName != null) {
313321
if (Engine.getAllEngines().contains(engineName)) {
314-
builder = builder.optEngine(engineName);
315-
foundEngine = true;
322+
selectedEngine = engineName;
316323
}
317324
}
318325

319326
// Try to figure out the engine name
320-
if (!foundEngine) {
327+
if (selectedEngine == null) {
321328
var urlString = urls.toString().toLowerCase();
322329
if (urlString.endsWith(".onnx") && Engine.hasEngine("OnnxRuntime"))
323-
builder = builder.optEngine("OnnxRuntime");
330+
selectedEngine = "OnnxRuntime";
324331
else if ((urlString.endsWith("pytorch") || urlString.endsWith(".pt")) && Engine.hasEngine("PyTorch"))
325-
builder = builder.optEngine("PyTorch");
332+
selectedEngine = "PyTorch";
326333
else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") || urlString.endsWith("tf_savedmodel")) && Engine.hasEngine("TensorFlow"))
327-
builder = builder.optEngine("TensorFlow");
334+
selectedEngine = "TensorFlow";
335+
}
336+
337+
if (selectedEngine != null) {
338+
builder.optEngine(selectedEngine);
339+
var device = defaultDevices.getOrDefault(selectedEngine, null);
340+
if (device != null) {
341+
builder.optDevice(device);
342+
builder.optOption("mapLocation", "true");
343+
}
328344
}
329345

330346
var criteria = builder.build();
331347
return ModelZoo.loadModel(criteria);
332348
}
333349

350+
351+
/**
352+
* Set the default device for the specified engine.
353+
* This will be used only whenever the model is build using this class, overriding
354+
* DJL's default.
355+
* <p>
356+
* Note that the default device chosen automatically by DJL is usually fine,
357+
* and so it is generally not necessary to set this.
358+
* However it can be useful for exploring, or if DJL does not use the device you want.
359+
* @param engineName
360+
* @param device
361+
*/
362+
public static void setOverrideDevice(String engineName, Device device) {
363+
if (device == null)
364+
defaultDevices.remove(engineName);
365+
else
366+
defaultDevices.put(engineName, device);
367+
}
368+
369+
/**
370+
* Get the default device for the specified engine, which overrides DJL's default device for
371+
* the specified engine.
372+
* @param engineName
373+
* @return the default device, or null if not set
374+
*/
375+
public static Device getOverrideDevice(String engineName) {
376+
return defaultDevices.getOrDefault(engineName, null);
377+
}
378+
334379
// static ZooModel<Mat, Mat> loadModelCV(URI uri, String ndLayout) throws ModelNotFoundException, MalformedModelException, IOException {
335380
// var criteria = Criteria.builder()
336381
// .setTypes(Mat.class, Mat.class)
@@ -373,9 +418,13 @@ public static NDArray matToNDArray(NDManager manager, Mat mat, String ndLayout)
373418
var buffer = mat.createBuffer();
374419
array = manager.create(buffer, shape, dataType);
375420
} else {
421+
var shapeDims = shape.getShape();
422+
shapeDims[indC] = 1;
423+
var shapeChannel = new Shape(shapeDims, shape.getLayout());
424+
376425
for (var mat2 : OpenCVTools.splitChannels(mat)) {
377-
var buffer = mat2.createBuffer();
378-
var arrayTemp = manager.create(buffer, shape, dataType);
426+
var buffer = mat2.createBuffer();
427+
var arrayTemp = manager.create(buffer, shapeChannel, dataType);
379428
if (array == null)
380429
array = arrayTemp;
381430
else {
@@ -469,26 +518,112 @@ public static Mat ndArrayToMat(NDArray array, String ndLayout, boolean doSqueeze
469518
if (indexer instanceof ByteIndexer) {
470519
((ByteIndexer) indexer).put(0L, array.toByteArray());
471520
} else if (indexer instanceof UByteIndexer) {
472-
((UByteIndexer) indexer).put(0L, array.toUint8Array());
521+
((UByteIndexer) indexer).put(0L, getInts(array));
473522
} else if (indexer instanceof UShortIndexer) {
474-
((UShortIndexer) indexer).put(0L, array.toIntArray());
523+
((UShortIndexer) indexer).put(0L, getInts(array));
475524
} else if (indexer instanceof IntIndexer) {
476-
((IntIndexer) indexer).put(0L, array.toIntArray());
525+
((IntIndexer) indexer).put(0L, getInts(array));
477526
} else if (indexer instanceof FloatIndexer) {
478-
((FloatIndexer) indexer).put(0L, array.toFloatArray());
527+
((FloatIndexer) indexer).put(0L, getFloats(array));
479528
} else if (indexer instanceof HalfIndexer) {
480-
((HalfIndexer) indexer).put(0L, array.toFloatArray());
529+
((HalfIndexer) indexer).put(0L,getFloats(array));
481530
} else if (indexer instanceof DoubleIndexer) {
482-
((DoubleIndexer) indexer).put(0L, array.toDoubleArray());
531+
((DoubleIndexer) indexer).put(0L, getDoubles(array));
483532
} else if (indexer instanceof LongIndexer) {
484-
((LongIndexer) indexer).put(0L, array.toLongArray());
533+
((LongIndexer) indexer).put(0L, getLongs(array));
485534
} else if (indexer instanceof BooleanIndexer) {
486-
((BooleanIndexer) indexer).put(0L, array.toBooleanArray());
535+
((BooleanIndexer) indexer).put(0L, getBooleans(array));
487536
} else
488537
throw new IllegalArgumentException("Unable to convert array " + array + " to Mat");
489538
}
490539
return mat;
491540
}
541+
542+
/**
543+
* Extract array values as longs, converting if necessary.
544+
* @param array
545+
* @return
546+
*/
547+
public static long[] getLongs(NDArray array) {
548+
if (array.getDataType() == DataType.INT64) {
549+
try {
550+
return array.toLongArray();
551+
} catch (Exception e) {
552+
logger.error("Exception requesting longs from NDArray");
553+
}
554+
}
555+
return array.toType(DataType.INT64, true).toLongArray();
556+
}
557+
558+
/**
559+
* Extract array values as booleans, converting if necessary.
560+
* @param array
561+
* @return
562+
*/
563+
private static boolean[] getBooleans(NDArray array) {
564+
if (array.getDataType() == DataType.BOOLEAN) {
565+
try {
566+
return array.toBooleanArray();
567+
} catch (Exception e) {
568+
logger.error("Exception requesting ints from NDArray");
569+
}
570+
}
571+
return array.toType(DataType.BOOLEAN, true).toBooleanArray();
572+
}
573+
574+
/**
575+
* Extract array values as ints, converting if necessary.
576+
* @param array
577+
* @return
578+
*/
579+
private static int[] getInts(NDArray array) {
580+
if (array.getDataType() == DataType.INT32) {
581+
try {
582+
return array.toIntArray();
583+
} catch (Exception e) {
584+
logger.error("Exception requesting ints from NDArray");
585+
}
586+
} else if (array.getDataType() == DataType.UINT8) {
587+
try {
588+
return array.toUint8Array();
589+
} catch (Exception e) {
590+
logger.error("Exception requesting ints from NDArray");
591+
}
592+
}
593+
return array.toType(DataType.INT32, true).toIntArray();
594+
}
595+
596+
/**
597+
* Extract array values as doubles, converting if necessary.
598+
* @param array
599+
* @return
600+
*/
601+
private static double[] getDoubles(NDArray array) {
602+
if (array.getDataType() == DataType.FLOAT64) {
603+
try {
604+
return array.toDoubleArray();
605+
} catch (Exception e) {
606+
logger.error("Exception requesting doubles from NDArray");
607+
}
608+
}
609+
return array.toType(DataType.FLOAT64, true).toDoubleArray();
610+
}
611+
612+
/**
613+
* Extract array values as floats, converting if necessary.
614+
* @param array
615+
* @return
616+
*/
617+
private static float[] getFloats(NDArray array) {
618+
if (array.getDataType() == DataType.FLOAT32 || array.getDataType() == DataType.FLOAT16) {
619+
try {
620+
return array.toFloatArray();
621+
} catch (Exception e) {
622+
logger.error("Exception requesting floats from NDArray", e);
623+
}
624+
}
625+
return array.toType(DataType.FLOAT32, true).toFloatArray();
626+
}
492627

493628

494629
static class MatTranslator implements Translator<Mat, Mat> {

0 commit comments

Comments
 (0)