|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +""" |
| 4 | +.. _example-catboost: |
| 5 | +
|
| 6 | +Convert a pipeline with a CatBoost classifier |
| 7 | +============================================= |
| 8 | +
|
| 9 | +.. index:: CatBoost |
| 10 | +
|
| 11 | +:epkg:`sklearn-onnx` only converts :epkg:`scikit-learn` models into *ONNX* |
| 12 | +but many libraries implement :epkg:`scikit-learn` API so that their models |
| 13 | +can be included in a :epkg:`scikit-learn` pipeline. This example considers |
| 14 | +a pipeline including a :epkg:`CatBoost` model. :epkg:`sklearn-onnx` can convert |
| 15 | +the whole pipeline as long as it knows the converter associated to |
| 16 | +a *CatBoostClassifier*. Let's see how to do it. |
| 17 | +
|
| 18 | +.. contents:: |
| 19 | + :local: |
| 20 | +
|
| 21 | +Train a CatBoostClassifier |
| 22 | +++++++++++++++++++++++++++ |
| 23 | +""" |
| 24 | +from pyquickhelper.helpgen.graphviz_helper import plot_graphviz |
| 25 | +import numpy |
| 26 | +from onnx.helper import get_attribute_value |
| 27 | +from sklearn.datasets import load_iris |
| 28 | +from sklearn.pipeline import Pipeline |
| 29 | +from sklearn.preprocessing import StandardScaler |
| 30 | +from mlprodict.onnxrt import OnnxInference |
| 31 | +import onnxruntime as rt |
| 32 | +from skl2onnx import convert_sklearn, update_registered_converter |
| 33 | +from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes # noqa |
| 34 | +from skl2onnx.common.data_types import FloatTensorType, Int64TensorType, guess_tensor_type |
| 35 | +from skl2onnx._parse import _apply_zipmap, _get_sklearn_operator_name |
| 36 | +from catboost import CatBoostClassifier |
| 37 | +from catboost.utils import convert_to_onnx_object |
| 38 | + |
# Load the iris dataset and keep only its first two features.
data = load_iris()
X = data.data[:, :2]
y = data.target

# Shuffle the rows; features and labels are reordered with the same
# permutation so they stay aligned.
ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()

# The pipeline scales the features and then trains a small CatBoost model.
# The second step is named after the estimator it actually holds.
pipe = Pipeline([('scaler', StandardScaler()),
                 ('catboost', CatBoostClassifier(n_estimators=3))])
pipe.fit(X, y)
| 51 | + |
| 52 | +###################################### |
| 53 | +# Register the converter for CatBoostClassifier |
| 54 | +# +++++++++++++++++++++++++++++++++++++++++++++ |
| 55 | +# |
| 56 | +# The model has no converter implemented in sklearn-onnx. |
| 57 | +# We need to register the one coming from *CatBoost* itself. |
| 58 | +# However, the converter does not follow sklearn-onnx design and |
| 59 | +# needs to be wrapped. |
| 60 | + |
| 61 | + |
def skl2onnx_parser_castboost_classifier(scope, model, inputs,
                                         custom_parsers=None):
    """
    Parser registered for *CatBoostClassifier*.

    It declares one operator for the model with two outputs, named from
    ``'label'`` (an int64 tensor) and ``'probabilities'`` (a tensor whose
    type is guessed from the first input). Unless option ``zipmap`` is
    exactly ``False``, the outputs go through ``_apply_zipmap``.

    :param scope: scope used to declare the operator and its variables
    :param model: the model to parse
    :param inputs: input variables, only the type of the first one is read
    :param custom_parsers: unused, kept for the parser signature
    :return: the operator outputs, or what ``_apply_zipmap`` returns
    """
    zipmap = scope.get_options(model, dict(zipmap=True))['zipmap']

    op = scope.declare_local_operator(
        _get_sklearn_operator_name(type(model)), model)
    op.inputs = inputs

    label = scope.declare_local_variable('label', Int64TensorType())
    probabilities = scope.declare_local_variable(
        'probabilities', guess_tensor_type(inputs[0].type))
    op.outputs.append(label)
    op.outputs.append(probabilities)

    # Only the boolean False disables the ZipMap; other values
    # (True, 'columns', ...) are forwarded to _apply_zipmap.
    if zipmap is False:
        return op.outputs

    return _apply_zipmap(zipmap, scope, model, inputs[0].type, op.outputs)
| 83 | + |
| 84 | + |
def skl2onnx_convert_catboost(scope, operator, container):
    """
    Converter wrapping the ONNX graph produced by CatBoost.

    CatBoost is asked to convert the model on its own
    (``convert_to_onnx_object``). The returned graph is expected to hold
    one *TreeEnsemble...* node, optionally followed by a *ZipMap* node.
    The first node is copied into the main graph with all its attributes
    and connected to the first input and the two outputs of *operator*.
    The optional *ZipMap* node is not copied.

    :param scope: unused, kept for the converter signature
    :param operator: operator to convert, ``operator.raw_operator`` is
        handed to CatBoost
    :param container: container receiving the new node
    :raises RuntimeError: if CatBoost relies on a default-domain opset
        more recent than ``container.target_opset``
    :raises NotImplementedError: if the graph returned by CatBoost has
        initializers or an unexpected list of nodes
    """
    onx = convert_to_onnx_object(operator.raw_operator)
    opsets = {d.domain: d.version for d in onx.opset_import}
    # An opset equal to the target one is fine, only a strictly more
    # recent one cannot be merged into the main graph.
    if '' in opsets and opsets[''] > container.target_opset:
        raise RuntimeError(
            "CatBoost uses an opset more recent than the target one.")
    if len(onx.graph.initializer) > 0 or len(onx.graph.sparse_initializer) > 0:
        raise NotImplementedError(
            "CatBoost returns a model with initializers. "
            "This option is not implemented yet.")

    nodes = onx.graph.node
    # The length is checked first so that nodes[0] is never read
    # on an empty graph.
    if (len(nodes) not in (1, 2) or
            not nodes[0].op_type.startswith("TreeEnsemble") or
            (len(nodes) == 2 and nodes[1].op_type != "ZipMap")):
        types = ", ".join(n.op_type for n in nodes)
        raise NotImplementedError(
            f"CatBoost returns {len(nodes)} node(s) (types={types}), "
            f"expected one TreeEnsemble node optionally followed by "
            f"a ZipMap node. This option is not implemented yet.")

    node = nodes[0]
    atts = {att.name: get_attribute_value(att) for att in node.attribute}
    container.add_node(
        node.op_type, [operator.inputs[0].full_name],
        [operator.outputs[0].full_name, operator.outputs[1].full_name],
        op_domain=node.domain, op_version=opsets.get(node.domain, None),
        **atts)
| 113 | + |
| 114 | + |
# Options, and their allowed values, declared for the registered converter.
catboost_options = {
    'nocl': [True, False],
    'zipmap': [True, False, 'columns'],
}

# Registers the shape calculator, the converter and the parser
# under the alias 'CatBoostCatBoostClassifier'.
update_registered_converter(
    CatBoostClassifier, 'CatBoostCatBoostClassifier',
    calculate_linear_classifier_output_shapes,
    skl2onnx_convert_catboost,
    parser=skl2onnx_parser_castboost_classifier,
    options=catboost_options)
| 122 | + |
| 123 | +################################## |
| 124 | +# Convert |
| 125 | +# +++++++ |
| 126 | + |
# One float input named 'input' with two columns and a free first dimension.
initial_types = [('input', FloatTensorType([None, 2]))]
model_onnx = convert_sklearn(
    pipe, 'pipeline_catboost', initial_types,
    target_opset={'': 12, 'ai.onnx.ml': 2})

# And save.
with open("pipeline_catboost.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())
| 135 | + |
| 136 | +########################### |
| 137 | +# Compare the predictions |
| 138 | +# +++++++++++++++++++++++ |
| 139 | +# |
| 140 | +# Predictions with CatBoost. |
| 141 | + |
# Labels for the first five rows, probabilities for the first row only.
expected_label = pipe.predict(X[:5])
print("predict", expected_label)
expected_proba = pipe.predict_proba(X[:1])
print("predict_proba", expected_proba)
| 144 | + |
| 145 | +########################## |
| 146 | +# Predictions with onnxruntime. |
| 147 | + |
# Recent releases of onnxruntime refuse to create a session without an
# explicit list of execution providers when the build supports several
# of them, so the CPU provider is requested explicitly.
sess = rt.InferenceSession(
    "pipeline_catboost.onnx", providers=["CPUExecutionProvider"])

# The model was converted with a float tensor input, hence the cast.
pred_onx = sess.run(None, {"input": X[:5].astype(numpy.float32)})
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1][:1])
| 153 | + |
| 154 | +############################# |
| 155 | +# Final graph |
| 156 | +# +++++++++++ |
| 157 | + |
# Draw the converted model and hide both axes of the resulting plot.
oinf = OnnxInference(model_onnx)
dot_graph = oinf.to_dot()
ax = plot_graphviz(dot_graph)
for get_axis in (ax.get_xaxis, ax.get_yaxis):
    get_axis().set_visible(False)
0 commit comments