Commit 3e49622

Add XGB & SKL Py handlers with CSV/json support (deepjavalibrary#2906)

1 parent 43e8126 commit 3e49622
File tree

11 files changed: +1172 -19 lines changed

engines/python/setup/djl_python/encode_decode.py

Lines changed: 27 additions & 15 deletions
@@ -21,21 +21,30 @@
 import numpy as np
 
 
-def decode_csv(inputs: Input):  # type: (str) -> np.array
+def decode_csv(inputs: Input, require_headers=True):  # type: (str) -> np.array
     csv_content = inputs.get_as_string()
-    stream = StringIO(csv_content)
-    # detects if the incoming csv has headers
-    if not any(header in csv_content.splitlines()[0].lower()
-               for header in ["question", "context", "inputs"]):
-        raise ValueError(
-            "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.",
-        )
-    # reads csv as io
-    request_list = list(csv.DictReader(stream))
-    if "inputs" in request_list[0].keys():
-        return {"inputs": [entry["inputs"] for entry in request_list]}
+
+    if require_headers:
+        if not any(header in csv_content.splitlines()[0].lower()
+                   for header in ["question", "context", "inputs"]):
+            raise ValueError(
+                "You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.",
+            )
+        stream = StringIO(csv_content)
+        request_list = list(csv.DictReader(stream))
+        if "inputs" in request_list[0].keys():
+            return {"inputs": [entry["inputs"] for entry in request_list]}
+        else:
+            return {"inputs": request_list}
     else:
-        return {"inputs": request_list}
+        # for predictive ML inputs
+        result = np.genfromtxt(StringIO(csv_content), delimiter=",")
+        # Check for NaN values which indicate non-numeric data
+        if np.isnan(result).any():
+            raise ValueError(
+                "CSV contains non-numeric data. Please provide numeric data only."
+            )
+        return result
 
 
 def encode_csv(content):  # type: (str) -> np.array

@@ -51,7 +60,10 @@ def encode_csv(content):  # type: (str) -> np.array
     return stream.getvalue()
 
 
-def decode(inputs: Input, content_type: str, key=None):
+def decode(inputs: Input,
+           content_type: str,
+           key=None,
+           require_csv_headers=True):
     if not content_type:
         ret = inputs.get_as_bytes(key=key)
         if not ret:

@@ -60,7 +72,7 @@ def decode(inputs: Input, content_type: str, key=None):
     elif "application/json" in content_type:
         return inputs.get_as_json(key=key)
     elif "text/csv" in content_type:
-        return decode_csv(inputs)
+        return decode_csv(inputs, require_headers=require_csv_headers)
     elif "text/plain" in content_type:
         return {"inputs": [inputs.get_as_string(key=key)]}
     if content_type.startswith("image/"):
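
The new require_headers=False branch targets tabular ML payloads: instead of csv.DictReader keyed on known header names, it parses the body as a numeric matrix via np.genfromtxt and rejects anything non-numeric. A minimal sketch of that branch in isolation; the FakeInput stub is illustrative only and mimics the one Input method this path uses, get_as_string():

import numpy as np
from io import StringIO


class FakeInput:
    """Illustrative stand-in for djl_python.inputs.Input; the headerless
    CSV path only ever calls get_as_string()."""

    def __init__(self, body: str):
        self._body = body

    def get_as_string(self) -> str:
        return self._body


def decode_headerless_csv(inputs) -> np.ndarray:
    # Mirrors the require_headers=False branch added above.
    result = np.genfromtxt(StringIO(inputs.get_as_string()), delimiter=",")
    # genfromtxt turns unparseable fields into NaN, so a single check
    # catches any non-numeric cell in the payload.
    if np.isnan(result).any():
        raise ValueError("CSV contains non-numeric data.")
    return result


print(decode_headerless_csv(FakeInput("1.0,2.5,3.0\n4.0,5.5,6.0")).shape)  # (2, 3)

Because np.genfromtxt fills unparseable fields with NaN, the one np.isnan(...).any() check is all the validation the numeric path needs.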
engines/python/setup/djl_python/import_utils.py

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+#
+# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+import importlib.util
+import importlib.metadata
+
+
+def _is_package_available(pkg_name: str) -> bool:
+    """Check if a package is available"""
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    if package_exists:
+        try:
+            importlib.metadata.version(pkg_name)
+        except importlib.metadata.PackageNotFoundError:
+            package_exists = False
+    return package_exists
+
+
+# SKLearn model persistence libraries
+_joblib_available = _is_package_available("joblib")
+_cloudpickle_available = _is_package_available("cloudpickle")
+_skops_available = _is_package_available("skops")
+
+# XGBoost
+_xgboost_available = _is_package_available("xgboost")
+
+
+def is_joblib_available() -> bool:
+    return _joblib_available
+
+
+def is_cloudpickle_available() -> bool:
+    return _cloudpickle_available
+
+
+def is_skops_available() -> bool:
+    return _skops_available
+
+
+def is_xgboost_available() -> bool:
+    return _xgboost_available
+
+
+joblib = None
+if _joblib_available:
+    import joblib
+
+cloudpickle = None
+if _cloudpickle_available:
+    import cloudpickle
+
+skops_io = None
+if _skops_available:
+    import skops.io as skops_io
+
+xgboost = None
+if _xgboost_available:
+    import xgboost
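
The module uses a guard-then-import pattern so handlers can reference joblib, cloudpickle, skops_io, and xgboost unconditionally and only hit a None (rather than an ImportError at module load) when a backend is missing. A standalone sketch of why the availability check consults both find_spec and the distribution metadata; nothing below depends on djl_python:

import importlib.util
import importlib.metadata


def is_available(pkg_name: str) -> bool:
    # find_spec() alone is not enough: a bare namespace package (or a
    # leftover directory on sys.path) yields a spec even though no
    # distribution is installed, so also require dist metadata.
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.version(pkg_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


xgboost = None
if is_available("xgboost"):
    import xgboost  # resolved once at import time, as import_utils does

print("xgboost available:", xgboost is not None)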
Lines changed: 162 additions & 0 deletions

@@ -0,0 +1,162 @@
+#!/usr/bin/env python
+#
+# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+import pickle
+import numpy as np
+import os
+from io import StringIO
+from typing import Optional
+from djl_python import Input, Output
+from djl_python.encode_decode import decode
+from djl_python.utils import find_model_file
+from djl_python.service_loader import get_annotated_function
+from djl_python.import_utils import joblib, cloudpickle, skops_io as sio
+
+
+class SklearnHandler:
+
+    def __init__(self):
+        self.model = None
+        self.initialized = False
+        self.custom_input_formatter = None
+        self.custom_output_formatter = None
+        self.custom_predict_formatter = None
+
+    def _get_trusted_types(self, properties: dict):
+        trusted_types_str = properties.get("skops_trusted_types", "")
+        if not trusted_types_str:
+            raise ValueError(
+                "option.skops_trusted_types must be set to load skops models. "
+                "Example: option.skops_trusted_types='sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray'"
+            )
+        trusted_types = [
+            t.strip() for t in trusted_types_str.split(",") if t.strip()
+        ]
+        print(f"Using trusted types for skops model loading: {trusted_types}")
+        return trusted_types
+
+    def initialize(self, properties: dict):
+        model_dir = properties.get("model_dir")
+        model_format = properties.get("model_format", "skops")
+
+        format_extensions = {
+            "skops": ["skops"],
+            "joblib": ["joblib", "jl"],
+            "pickle": ["pkl", "pickle"],
+            "cloudpickle": ["pkl", "pickle", "cloudpkl"]
+        }
+
+        extensions = format_extensions.get(model_format)
+        if not extensions:
+            raise ValueError(
+                f"Unsupported model format: {model_format}. Supported formats: skops, joblib, pickle, cloudpickle"
+            )
+
+        model_file = find_model_file(model_dir, extensions)
+        if not model_file:
+            raise FileNotFoundError(
+                f"No model file found with format '{model_format}' in {model_dir}"
+            )
+
+        if model_format == "skops":
+            trusted_types = self._get_trusted_types(properties)
+            self.model = sio.load(model_file, trusted=trusted_types)
+        else:
+            if properties.get("trust_insecure_model_files",
+                              "false").lower() != "true":
+                raise ValueError(
+                    f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
+                )
+
+            if model_format == "joblib":
+                self.model = joblib.load(model_file)
+            elif model_format == "pickle":
+                with open(model_file, 'rb') as f:
+                    self.model = pickle.load(f)
+            elif model_format == "cloudpickle":
+                with open(model_file, 'rb') as f:
+                    self.model = cloudpickle.load(f)
+
+        self.custom_input_formatter = get_annotated_function(
+            model_dir, "is_input_formatter")
+        self.custom_output_formatter = get_annotated_function(
+            model_dir, "is_output_formatter")
+        self.custom_predict_formatter = get_annotated_function(
+            model_dir, "is_predict_formatter")
+
+        self.initialized = True
+
+    def inference(self, inputs: Input) -> Output:
+        content_type = inputs.get_property("Content-Type")
+        accept = inputs.get_property("Accept") or "application/json"
+
+        # Validate accept type (skip validation if custom output formatter is provided)
+        if not self.custom_output_formatter:
+            supported_accept_types = ["application/json", "text/csv"]
+            if not any(supported_type in accept
+                       for supported_type in supported_accept_types):
+                raise ValueError(
+                    f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}"
+                )
+
+        # Input processing
+        X = None
+        if self.custom_input_formatter:
+            X = self.custom_input_formatter(inputs)
+        elif "text/csv" in content_type:
+            X = decode(inputs, content_type, require_csv_headers=False)
+        else:
+            input_map = decode(inputs, content_type)
+            data = input_map.get("inputs") if isinstance(input_map,
+                                                         dict) else input_map
+            X = np.array(data)
+
+        if X is None or not hasattr(X, 'ndim'):
+            raise ValueError(
+                f"Input processing failed for content type {content_type}")
+
+        if X.ndim == 1:
+            X = X.reshape(1, -1)
+
+        if self.custom_predict_formatter:
+            predictions = self.custom_predict_formatter(self.model, X)
+        else:
+            predictions = self.model.predict(X)
+
+        # Output processing
+        if self.custom_output_formatter:
+            return self.custom_output_formatter(predictions)
+
+        # Supports CSV/JSON outputs by default
+        outputs = Output()
+        if "text/csv" in accept:
+            csv_buffer = StringIO()
+            np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',')
+            outputs.add(csv_buffer.getvalue().rstrip())
+            outputs.add_property("Content-Type", "text/csv")
+        else:
+            outputs.add_as_json({"predictions": predictions.tolist()})
+        return outputs
+
+
+service = SklearnHandler()
+
+
+def handle(inputs: Input) -> Optional[Output]:
+    if not service.initialized:
+        service.initialize(inputs.get_properties())
+
+    if inputs.is_empty():
+        return None
+
+    return service.inference(inputs)
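
For the default skops format, the artifact has to be produced with skops.io, and any custom types must be whitelisted at load time via option.skops_trusted_types (see _get_trusted_types above). A sketch of preparing such an artifact, assuming scikit-learn and skops are installed; the file name is arbitrary since the handler locates the file by its .skops extension:

import numpy as np
import skops.io as sio
from sklearn.ensemble import RandomForestClassifier

# Train a toy classifier and persist it in the skops format the
# handler loads by default (model_format defaults to "skops").
X = np.random.rand(100, 4)
y = (X[:, 0] > 0.5).astype(int)
model = RandomForestClassifier(n_estimators=10).fit(X, y)
sio.dump(model, "model.skops")  # discovered via its .skops extension

# The types reported here are what option.skops_trusted_types must list
# so that sio.load(..., trusted=[...]) will accept the artifact.
print(sio.get_untrusted_types(file="model.skops"))

The other formats (joblib, pickle, cloudpickle) can execute arbitrary code on deserialization, which is why the handler gates them behind option.trust_insecure_model_files=true.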

engines/python/setup/djl_python/utils.py

Lines changed: 27 additions & 0 deletions
@@ -10,7 +10,10 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
+import glob
 import logging
+import os
+from typing import Optional, List
 
 from djl_python import Output
 from djl_python.inputs import Input

@@ -161,3 +164,27 @@ def get_input_details(requests, errors, batch):
         idx += 1
     adapters = adapters if adapters else None
     return input_data, input_size, parameters, adapters
+
+
+def find_model_file(model_dir: str, extensions: List[str]) -> Optional[str]:
+    """Find model file with given extensions in model directory
+
+    Args:
+        model_dir: Directory to search for model files
+        extensions: List of file extensions to search for (without dots)
+
+    Returns:
+        Path to matching model file, or None if not found
+    """
+    all_matches = []
+    for ext in extensions:
+        pattern = os.path.join(model_dir, f"*.{ext}")
+        matches = glob.glob(pattern)
+        all_matches.extend(matches)
+
+    if len(all_matches) > 1:
+        raise ValueError(
+            f"Multiple model files found in {model_dir}: {all_matches}. Only one model file is supported per directory."
+        )
+
+    return all_matches[0] if all_matches else None
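
A quick usage sketch of the new helper's contract, assuming djl_python is importable (e.g., inside the serving container):

import os
import tempfile

from djl_python.utils import find_model_file

with tempfile.TemporaryDirectory() as model_dir:
    open(os.path.join(model_dir, "model.skops"), "w").close()

    print(find_model_file(model_dir, ["skops"]))         # .../model.skops
    print(find_model_file(model_dir, ["joblib", "jl"]))  # None

    # Two candidate files make resolution ambiguous and raise ValueError.
    open(os.path.join(model_dir, "extra.skops"), "w").close()
    try:
        find_model_file(model_dir, ["skops"])
    except ValueError as err:
        print(err)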
