#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import pickle
import numpy as np
import os
from io import StringIO
from typing import Optional
from djl_python import Input, Output
from djl_python.encode_decode import decode
from djl_python.utils import find_model_file
from djl_python.service_loader import get_annotated_function
from djl_python.import_utils import joblib, cloudpickle, skops_io as sio


class SklearnHandler:
    """DJL Serving handler for scikit-learn models (skops, joblib, pickle, cloudpickle)."""

    def __init__(self):
        self.model = None
        self.initialized = False
        self.custom_input_formatter = None
        self.custom_output_formatter = None
        self.custom_predict_formatter = None

    def _get_trusted_types(self, properties: dict):
        """Parse the required option.skops_trusted_types property into a list of type names."""
        trusted_types_str = properties.get("skops_trusted_types", "")
        if not trusted_types_str:
            raise ValueError(
                "option.skops_trusted_types must be set to load skops models. "
                "Example: option.skops_trusted_types='sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray'"
            )
        trusted_types = [
            t.strip() for t in trusted_types_str.split(",") if t.strip()
        ]
        print(f"Using trusted types for skops model loading: {trusted_types}")
        return trusted_types

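    # Illustrative only, not used by this handler: before setting
    # option.skops_trusted_types, a model author can inspect the artifact
    # with the skops.io API to see which types it needs (assuming skops is
    # installed where the model was built):
    #
    #     from skops.io import get_untrusted_types
    #     untrusted = get_untrusted_types(file="model.skops")
    #     # review `untrusted`, then put the approved names in
    #     # serving.properties, e.g.:
    #     # option.skops_trusted_types=sklearn.ensemble._forest.RandomForestClassifier,numpy.ndarray
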
    def initialize(self, properties: dict):
        """Locate and load the model file according to option.model_format."""
        model_dir = properties.get("model_dir")
        model_format = properties.get("model_format", "skops")

        format_extensions = {
            "skops": ["skops"],
            "joblib": ["joblib", "jl"],
            "pickle": ["pkl", "pickle"],
            "cloudpickle": ["pkl", "pickle", "cloudpkl"]
        }

        extensions = format_extensions.get(model_format)
        if not extensions:
            raise ValueError(
                f"Unsupported model format: {model_format}. Supported formats: skops, joblib, pickle, cloudpickle"
            )

        model_file = find_model_file(model_dir, extensions)
        if not model_file:
            raise FileNotFoundError(
                f"No model file found with format '{model_format}' in {model_dir}"
            )

        if model_format == "skops":
            # skops validates types against an explicit allow-list, so it is
            # the only format loaded without an insecure-files opt-in.
            trusted_types = self._get_trusted_types(properties)
            self.model = sio.load(model_file, trusted=trusted_types)
        else:
            if properties.get("trust_insecure_model_files",
                              "false").lower() != "true":
                raise ValueError(
                    f"option.trust_insecure_model_files must be set to 'true' to use {model_format} format (only skops is secure by default)"
                )

            if model_format == "joblib":
                self.model = joblib.load(model_file)
            elif model_format == "pickle":
                with open(model_file, 'rb') as f:
                    self.model = pickle.load(f)
            elif model_format == "cloudpickle":
                with open(model_file, 'rb') as f:
                    self.model = cloudpickle.load(f)
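
        # Illustrative sketch (not executed by the handler): how artifacts in
        # each supported format are typically produced on the training side,
        # assuming the same libraries are available there; file names are
        # examples only:
        #
        #     import skops.io as sio
        #     sio.dump(model, "model.skops")        # secure default format
        #     joblib.dump(model, "model.joblib")    # needs trust_insecure_model_files=true
        #     with open("model.pkl", "wb") as f:
        #         pickle.dump(model, f)             # needs trust_insecure_model_files=true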

        # Optional user-provided formatter functions discovered in the model directory.
        self.custom_input_formatter = get_annotated_function(
            model_dir, "is_input_formatter")
        self.custom_output_formatter = get_annotated_function(
            model_dir, "is_output_formatter")
        self.custom_predict_formatter = get_annotated_function(
            model_dir, "is_predict_formatter")
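
        # Hypothetical sketch of a user-supplied formatter in the model
        # directory, assuming get_annotated_function matches callables by the
        # attribute names used above (the exact registration mechanism may
        # differ in your djl_python version):
        #
        #     def probability_predictor(model, X):
        #         return model.predict_proba(X)
        #     probability_predictor.is_predict_formatter = True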

        self.initialized = True

    def inference(self, inputs: Input) -> Output:
        """Run prediction on a single request and format the response."""
        # Default both headers to JSON when absent.
        content_type = inputs.get_property("Content-Type") or "application/json"
        accept = inputs.get_property("Accept") or "application/json"

        # Validate accept type (skip validation if a custom output formatter is provided)
        if not self.custom_output_formatter:
            supported_accept_types = ["application/json", "text/csv"]
            if not any(supported_type in accept
                       for supported_type in supported_accept_types):
                raise ValueError(
                    f"Unsupported Accept type: {accept}. Supported types: {supported_accept_types}"
                )

        # Input processing
        X = None
        if self.custom_input_formatter:
            X = self.custom_input_formatter(inputs)
        elif "text/csv" in content_type:
            X = decode(inputs, content_type, require_csv_headers=False)
        else:
            input_map = decode(inputs, content_type)
            data = input_map.get("inputs") if isinstance(input_map,
                                                         dict) else input_map
            X = np.array(data)

        if X is None or not hasattr(X, 'ndim'):
            raise ValueError(
                f"Input processing failed for content type {content_type}")

        # Treat a single flat row as one sample with many features.
        if X.ndim == 1:
            X = X.reshape(1, -1)

        if self.custom_predict_formatter:
            predictions = self.custom_predict_formatter(self.model, X)
        else:
            predictions = self.model.predict(X)

        # Output processing
        if self.custom_output_formatter:
            return self.custom_output_formatter(predictions)

        # Supports CSV/JSON outputs by default
        outputs = Output()
        if "text/csv" in accept:
            csv_buffer = StringIO()
            np.savetxt(csv_buffer, predictions, fmt='%s', delimiter=',')
            outputs.add(csv_buffer.getvalue().rstrip())
            outputs.add_property("Content-Type", "text/csv")
        else:
            outputs.add_as_json({"predictions": predictions.tolist()})
        return outputs


service = SklearnHandler()


def handle(inputs: Input) -> Optional[Output]:
    """DJL Serving entry point: lazily initialize the handler, then run inference."""
    if not service.initialized:
        service.initialize(inputs.get_properties())

    if inputs.is_empty():
        return None

    return service.inference(inputs)
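

# Example request/response payloads for the default (no custom formatter)
# code paths above; the predicted values shown are illustrative:
#
#   JSON request  (Content-Type: application/json):  {"inputs": [[5.1, 3.5, 1.4, 0.2]]}
#   JSON response (Accept: application/json):        {"predictions": [0]}
#
#   CSV request   (Content-Type: text/csv, no header row):
#       5.1,3.5,1.4,0.2
#   CSV response  (Accept: text/csv):
#       0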