Skip to content

Commit 1032c47

Browse files
authored
Auto detect handler for sklearn (deepjavalibrary#2917)
1 parent 3e49622 commit 1032c47

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

engines/python/setup/djl_python/ts_service_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ def invoke_handler(self, function_name, inputs):
7676
if content_type is not None:
7777
outputs.add_property("content-type", content_type)
7878

79-
if ts_out is None:
79+
if not ts_out:
8080
outputs.message = "No content"
81+
output.add("No content")
8182
else:
8283
val = ts_out[0]
8384
if isinstance(val, torch.Tensor):

engines/python/src/main/java/ai/djl/python/engine/PyModel.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,10 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
179179
entryPoint = modelFile.toFile().getName();
180180
}
181181
// find recommendedEntryPoint
182-
if ("nc".equals(manager.getDevice().getDeviceType())
182+
if (hasModelFile(
183+
modelDir, prefix, ".skops", ".joblib", ".pkl", ".pickle", ".cloudpkl")) {
184+
recommendedEntryPoint = "djl_python.sklearn_handler";
185+
} else if ("nc".equals(manager.getDevice().getDeviceType())
183186
&& pyEnv.getTensorParallelDegree() > 0) {
184187
recommendedEntryPoint = "djl_python.transformers_neuronx";
185188
} else if ("trtllm".equals(features)) {
@@ -327,6 +330,18 @@ private Path findModelFile(String prefix) {
327330
return modelFile;
328331
}
329332

333+
private boolean hasModelFile(Path modelDir, String prefix, String... extensions) {
334+
for (String extension : extensions) {
335+
if (Files.isRegularFile(modelDir.resolve(prefix + extension))) {
336+
return true;
337+
}
338+
if (Files.isRegularFile(modelDir.resolve("model" + extension))) {
339+
return true;
340+
}
341+
}
342+
return false;
343+
}
344+
330345
private void createAllPyProcesses(int mpiWorkers, int worldSize) {
331346
long begin = System.currentTimeMillis();
332347
ExecutorService pool = null;

wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,18 @@ public Device withDefaultDevice(String deviceName) {
742742
return Device.fromName(deviceName, Engine.getEngine(engineName));
743743
}
744744

745+
private boolean hasModelFile(Path modelDir, String prefix, String... extensions) {
746+
for (String extension : extensions) {
747+
if (Files.isRegularFile(modelDir.resolve(prefix + extension))) {
748+
return true;
749+
}
750+
if (Files.isRegularFile(modelDir.resolve("model" + extension))) {
751+
return true;
752+
}
753+
}
754+
return false;
755+
}
756+
745757
private String inferEngine() throws ModelException {
746758
String eng = prop.getProperty("engine");
747759
if (eng != null) {
@@ -768,23 +780,19 @@ private String inferEngine() throws ModelException {
768780
return zoo.getSupportedEngines().iterator().next();
769781
} else if (isTorchServeModel()) {
770782
return "Python";
771-
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".pt"))
772-
|| Files.isRegularFile(modelDir.resolve("model.pt"))) {
783+
} else if (hasModelFile(modelDir, prefix, ".pt")) {
773784
return "PyTorch";
774785
} else if (Files.isRegularFile(modelDir.resolve("config.pbtxt"))) {
775786
return "TritonServer";
776787
} else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
777788
return "TensorFlow";
778-
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".onnx"))
779-
|| Files.isRegularFile(modelDir.resolve("model.onnx"))) {
789+
} else if (hasModelFile(modelDir, prefix, ".onnx")) {
780790
return "OnnxRuntime";
781-
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".json"))
782-
|| Files.isRegularFile(modelDir.resolve(prefix + ".xgb"))
783-
|| Files.isRegularFile(modelDir.resolve(prefix + ".bst"))
784-
|| Files.isRegularFile(modelDir.resolve("model.json"))
785-
|| Files.isRegularFile(modelDir.resolve("model.bst"))
786-
|| Files.isRegularFile(modelDir.resolve("model.xgb"))) {
791+
} else if (hasModelFile(modelDir, prefix, ".json", ".xgb", ".bst")) {
787792
return "XGBoost";
793+
} else if (hasModelFile(
794+
modelDir, prefix, ".skops", ".joblib", ".pkl", ".pickle", ".cloudpkl")) {
795+
return "Python";
788796
} else if (isPythonModel(prefix)) {
789797
// TODO: How to differentiate Rust model from Python
790798
return "Python";

0 commit comments

Comments
 (0)