|
22 | 22 | import java.util.Arrays; |
23 | 23 | import java.util.Collection; |
24 | 24 | import java.util.Collections; |
| 25 | +import java.util.HashMap; |
25 | 26 | import java.util.HashSet; |
26 | 27 | import java.util.List; |
27 | 28 | import java.util.Map; |
28 | 29 | import java.util.Set; |
29 | 30 |
|
| 31 | +import ai.djl.Device; |
30 | 32 | import org.bytedeco.javacpp.Loader; |
31 | 33 | import org.bytedeco.javacpp.PointerScope; |
32 | 34 | import org.bytedeco.javacpp.indexer.BooleanIndexer; |
@@ -131,6 +133,12 @@ public class DjlTools { |
131 | 133 | * having to try to instantiate a new one. |
132 | 134 | */ |
133 | 135 | public static Set<String> loadedEngines = new HashSet<>(); |
| 136 | + |
| 137 | + /** |
| 138 | + * Default devices for each engine. |
| 139 | + * This can be used to override the default used by DJL. |
| 140 | + */ |
| 141 | + private static Map<String, Device> defaultDevices = new HashMap<>(); |
134 | 142 |
|
135 | 143 | static Set<String> ALL_ENGINES = Set.of( |
136 | 144 | ENGINE_DLR, ENGINE_LIGHTGBM, ENGINE_MXNET, ENGINE_ONNX_RUNTIME, ENGINE_PADDLEPADDLE, |
@@ -308,29 +316,66 @@ private static ZooModel<NDList, NDList> loadModel(String engineName, String urls |
308 | 316 | .optModelUrls(urls) |
309 | 317 | .optProgress(new ProgressBar()); |
310 | 318 |
|
311 | | - boolean foundEngine = false; |
| 319 | + String selectedEngine = null; |
312 | 320 | if (engineName != null) { |
313 | 321 | if (Engine.getAllEngines().contains(engineName)) { |
314 | | - builder = builder.optEngine(engineName); |
315 | | - foundEngine = true; |
| 322 | + selectedEngine = engineName; |
316 | 323 | } |
317 | 324 | } |
318 | 325 |
|
319 | 326 | // Try to figure out the engine name |
320 | | - if (!foundEngine) { |
| 327 | + if (selectedEngine == null) { |
321 | 328 | var urlString = urls.toString().toLowerCase(); |
322 | 329 | if (urlString.endsWith(".onnx") && Engine.hasEngine("OnnxRuntime")) |
323 | | - builder = builder.optEngine("OnnxRuntime"); |
| 330 | + selectedEngine = "OnnxRuntime"; |
324 | 331 | else if ((urlString.endsWith("pytorch") || urlString.endsWith(".pt")) && Engine.hasEngine("PyTorch")) |
325 | | - builder = builder.optEngine("PyTorch"); |
| 332 | + selectedEngine = "PyTorch"; |
326 | 333 | else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") || urlString.endsWith("tf_savedmodel")) && Engine.hasEngine("TensorFlow")) |
327 | | - builder = builder.optEngine("TensorFlow"); |
| 334 | + selectedEngine = "TensorFlow"; |
| 335 | + } |
| 336 | + |
| 337 | + if (selectedEngine != null) { |
| 338 | + builder.optEngine(selectedEngine); |
| 339 | + var device = defaultDevices.getOrDefault(selectedEngine, null); |
| 340 | + if (device != null) { |
| 341 | + builder.optDevice(device); |
| 342 | + builder.optOption("mapLocation", "true"); |
| 343 | + } |
328 | 344 | } |
329 | 345 |
|
330 | 346 | var criteria = builder.build(); |
331 | 347 | return ModelZoo.loadModel(criteria); |
332 | 348 | } |
333 | 349 |
|
| 350 | + |
| 351 | + /** |
| 352 | + * Set the default device for the specified engine. |
| 353 | + * This will be used only whenever the model is build using this class, overriding |
| 354 | + * DJL's default. |
| 355 | + * <p> |
| 356 | + * Note that the default device chosen automatically by DJL is usually fine, |
| 357 | + * and so it is generally not necessary to set this. |
| 358 | + * However it can be useful for exploring, or if DJL does not use the device you want. |
| 359 | + * @param engineName |
| 360 | + * @param device |
| 361 | + */ |
| 362 | + public static void setOverrideDevice(String engineName, Device device) { |
| 363 | + if (device == null) |
| 364 | + defaultDevices.remove(engineName); |
| 365 | + else |
| 366 | + defaultDevices.put(engineName, device); |
| 367 | + } |
| 368 | + |
| 369 | + /** |
| 370 | + * Get the default device for the specified engine, which overrides DJL's default device for |
| 371 | + * the specified engine. |
| 372 | + * @param engineName |
| 373 | + * @return the default device, or null if not set |
| 374 | + */ |
| 375 | + public static Device getOverrideDevice(String engineName) { |
| 376 | + return defaultDevices.getOrDefault(engineName, null); |
| 377 | + } |
| 378 | + |
334 | 379 | // static ZooModel<Mat, Mat> loadModelCV(URI uri, String ndLayout) throws ModelNotFoundException, MalformedModelException, IOException { |
335 | 380 | // var criteria = Criteria.builder() |
336 | 381 | // .setTypes(Mat.class, Mat.class) |
@@ -373,9 +418,13 @@ public static NDArray matToNDArray(NDManager manager, Mat mat, String ndLayout) |
373 | 418 | var buffer = mat.createBuffer(); |
374 | 419 | array = manager.create(buffer, shape, dataType); |
375 | 420 | } else { |
| 421 | + var shapeDims = shape.getShape(); |
| 422 | + shapeDims[indC] = 1; |
| 423 | + var shapeChannel = new Shape(shapeDims, shape.getLayout()); |
| 424 | + |
376 | 425 | for (var mat2 : OpenCVTools.splitChannels(mat)) { |
377 | | - var buffer = mat2.createBuffer(); |
378 | | - var arrayTemp = manager.create(buffer, shape, dataType); |
| 426 | + var buffer = mat2.createBuffer(); |
| 427 | + var arrayTemp = manager.create(buffer, shapeChannel, dataType); |
379 | 428 | if (array == null) |
380 | 429 | array = arrayTemp; |
381 | 430 | else { |
@@ -469,26 +518,112 @@ public static Mat ndArrayToMat(NDArray array, String ndLayout, boolean doSqueeze |
469 | 518 | if (indexer instanceof ByteIndexer) { |
470 | 519 | ((ByteIndexer) indexer).put(0L, array.toByteArray()); |
471 | 520 | } else if (indexer instanceof UByteIndexer) { |
472 | | - ((UByteIndexer) indexer).put(0L, array.toUint8Array()); |
| 521 | + ((UByteIndexer) indexer).put(0L, getInts(array)); |
473 | 522 | } else if (indexer instanceof UShortIndexer) { |
474 | | - ((UShortIndexer) indexer).put(0L, array.toIntArray()); |
| 523 | + ((UShortIndexer) indexer).put(0L, getInts(array)); |
475 | 524 | } else if (indexer instanceof IntIndexer) { |
476 | | - ((IntIndexer) indexer).put(0L, array.toIntArray()); |
| 525 | + ((IntIndexer) indexer).put(0L, getInts(array)); |
477 | 526 | } else if (indexer instanceof FloatIndexer) { |
478 | | - ((FloatIndexer) indexer).put(0L, array.toFloatArray()); |
| 527 | + ((FloatIndexer) indexer).put(0L, getFloats(array)); |
479 | 528 | } else if (indexer instanceof HalfIndexer) { |
480 | | - ((HalfIndexer) indexer).put(0L, array.toFloatArray()); |
| 529 | + ((HalfIndexer) indexer).put(0L,getFloats(array)); |
481 | 530 | } else if (indexer instanceof DoubleIndexer) { |
482 | | - ((DoubleIndexer) indexer).put(0L, array.toDoubleArray()); |
| 531 | + ((DoubleIndexer) indexer).put(0L, getDoubles(array)); |
483 | 532 | } else if (indexer instanceof LongIndexer) { |
484 | | - ((LongIndexer) indexer).put(0L, array.toLongArray()); |
| 533 | + ((LongIndexer) indexer).put(0L, getLongs(array)); |
485 | 534 | } else if (indexer instanceof BooleanIndexer) { |
486 | | - ((BooleanIndexer) indexer).put(0L, array.toBooleanArray()); |
| 535 | + ((BooleanIndexer) indexer).put(0L, getBooleans(array)); |
487 | 536 | } else |
488 | 537 | throw new IllegalArgumentException("Unable to convert array " + array + " to Mat"); |
489 | 538 | } |
490 | 539 | return mat; |
491 | 540 | } |
| 541 | + |
| 542 | + /** |
| 543 | + * Extract array values as longs, converting if necessary. |
| 544 | + * @param array |
| 545 | + * @return |
| 546 | + */ |
| 547 | + public static long[] getLongs(NDArray array) { |
| 548 | + if (array.getDataType() == DataType.INT64) { |
| 549 | + try { |
| 550 | + return array.toLongArray(); |
| 551 | + } catch (Exception e) { |
| 552 | + logger.error("Exception requesting longs from NDArray"); |
| 553 | + } |
| 554 | + } |
| 555 | + return array.toType(DataType.INT64, true).toLongArray(); |
| 556 | + } |
| 557 | + |
| 558 | + /** |
| 559 | + * Extract array values as booleans, converting if necessary. |
| 560 | + * @param array |
| 561 | + * @return |
| 562 | + */ |
| 563 | + private static boolean[] getBooleans(NDArray array) { |
| 564 | + if (array.getDataType() == DataType.BOOLEAN) { |
| 565 | + try { |
| 566 | + return array.toBooleanArray(); |
| 567 | + } catch (Exception e) { |
| 568 | + logger.error("Exception requesting ints from NDArray"); |
| 569 | + } |
| 570 | + } |
| 571 | + return array.toType(DataType.BOOLEAN, true).toBooleanArray(); |
| 572 | + } |
| 573 | + |
| 574 | + /** |
| 575 | + * Extract array values as ints, converting if necessary. |
| 576 | + * @param array |
| 577 | + * @return |
| 578 | + */ |
| 579 | + private static int[] getInts(NDArray array) { |
| 580 | + if (array.getDataType() == DataType.INT32) { |
| 581 | + try { |
| 582 | + return array.toIntArray(); |
| 583 | + } catch (Exception e) { |
| 584 | + logger.error("Exception requesting ints from NDArray"); |
| 585 | + } |
| 586 | + } else if (array.getDataType() == DataType.UINT8) { |
| 587 | + try { |
| 588 | + return array.toUint8Array(); |
| 589 | + } catch (Exception e) { |
| 590 | + logger.error("Exception requesting ints from NDArray"); |
| 591 | + } |
| 592 | + } |
| 593 | + return array.toType(DataType.INT32, true).toIntArray(); |
| 594 | + } |
| 595 | + |
| 596 | + /** |
| 597 | + * Extract array values as doubles, converting if necessary. |
| 598 | + * @param array |
| 599 | + * @return |
| 600 | + */ |
| 601 | + private static double[] getDoubles(NDArray array) { |
| 602 | + if (array.getDataType() == DataType.FLOAT64) { |
| 603 | + try { |
| 604 | + return array.toDoubleArray(); |
| 605 | + } catch (Exception e) { |
| 606 | + logger.error("Exception requesting doubles from NDArray"); |
| 607 | + } |
| 608 | + } |
| 609 | + return array.toType(DataType.FLOAT64, true).toDoubleArray(); |
| 610 | + } |
| 611 | + |
| 612 | + /** |
| 613 | + * Extract array values as floats, converting if necessary. |
| 614 | + * @param array |
| 615 | + * @return |
| 616 | + */ |
| 617 | + private static float[] getFloats(NDArray array) { |
| 618 | + if (array.getDataType() == DataType.FLOAT32 || array.getDataType() == DataType.FLOAT16) { |
| 619 | + try { |
| 620 | + return array.toFloatArray(); |
| 621 | + } catch (Exception e) { |
| 622 | + logger.error("Exception requesting floats from NDArray", e); |
| 623 | + } |
| 624 | + } |
| 625 | + return array.toType(DataType.FLOAT32, true).toFloatArray(); |
| 626 | + } |
492 | 627 |
|
493 | 628 |
|
494 | 629 | static class MatTranslator implements Translator<Mat, Mat> { |
|
0 commit comments