Skip to content

Commit 3ef5e13

Browse files
authored
Fix converter for DecisionTreeClassifier if n_classes == 1 (#1008)
* Fix converter for DecisionTreeClassifier if n_classses == 1 Signed-off-by: Xavier Dupre <[email protected]> * list or np.array Signed-off-by: Xavier Dupre <[email protected]> * lint Signed-off-by: Xavier Dupre <[email protected]> * froze lightgbm version Signed-off-by: Xavier Dupre <[email protected]> * black Signed-off-by: Xavier Dupre <[email protected]> * Refactor with black (#1009) * Refactor with black Signed-off-by: Xavier Dupre <[email protected]> * remove unnecessary skip condition Signed-off-by: Xavier Dupre <[email protected]> * freeze lightgbm version Signed-off-by: Xavier Dupre <[email protected]> * add ruff to github action Signed-off-by: Xavier Dupre <[email protected]> * update badge on README.md Signed-off-by: Xavier Dupre <[email protected]> --------- Signed-off-by: Xavier Dupre <[email protected]> * fix old CI Signed-off-by: Xavier Dupre <[email protected]> --------- Signed-off-by: Xavier Dupre <[email protected]> Signed-off-by: Xavier Dupré <[email protected]>
1 parent 8a4a803 commit 3ef5e13

File tree

4 files changed

+97
-6
lines changed

4 files changed

+97
-6
lines changed

docs/examples/plot_cast_transformer.py

-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
"""
3131
import onnxruntime
3232
import onnx
33-
import numpy
3433
import os
3534
import math
3635
import numpy as np

docs/examples/plot_tfidfvectorizer.py

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import matplotlib.pyplot as plt
2525
import os
2626
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
27-
import numpy
2827
import onnxruntime as rt
2928
from skl2onnx.common.data_types import StringTensorType
3029
from skl2onnx import convert_sklearn

skl2onnx/operator_converters/decision_tree.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numbers
55
import numpy as np
6+
from onnx.numpy_helper import from_array
67
from ..common._apply_operation import (
78
apply_cast,
89
apply_concat,
@@ -124,7 +125,7 @@ def predict(
124125
[indices_name, dummy_proba_name],
125126
op_domain=op_domain,
126127
op_version=op_version,
127-
**attrs
128+
**attrs,
128129
)
129130
else:
130131
zero_name = scope.get_unique_variable_name("zero")
@@ -243,7 +244,7 @@ def _append_decision_output(
243244
dpath,
244245
op_domain=op_domain,
245246
op_version=op_version,
246-
**attrs
247+
**attrs,
247248
)
248249

249250
if n_out is None:
@@ -306,6 +307,53 @@ def convert_sklearn_decision_tree_classifier(
306307
dtype = np.float32
307308
op = operator.raw_operator
308309
options = scope.get_options(op, dict(decision_path=False, decision_leaf=False))
310+
if np.asarray(op.classes_).size == 1:
311+
# The model was trained with one label.
312+
# There is no need to build a tree.
313+
if op.n_outputs_ != 1:
314+
raise RuntimeError(
315+
f"One training class and multiple outputs is not "
316+
f"supported yet for class {op.__class__.__name__!r}."
317+
)
318+
if options["decision_path"] or options["decision_leaf"]:
319+
raise RuntimeError(
320+
f"One training class, option 'decision_path' "
321+
f"or 'decision_leaf' are not supported for "
322+
f"class {op.__class__.__name__!r}."
323+
)
324+
325+
zero = scope.get_unique_variable_name("zero")
326+
one = scope.get_unique_variable_name("one")
327+
new_shape = scope.get_unique_variable_name("new_shape")
328+
container.add_initializer(zero, onnx_proto.TensorProto.INT64, [1], [0])
329+
container.add_initializer(one, onnx_proto.TensorProto.INT64, [1], [1])
330+
container.add_initializer(new_shape, onnx_proto.TensorProto.INT64, [2], [-1, 1])
331+
shape = scope.get_unique_variable_name("shape")
332+
container.add_node("Shape", [operator.inputs[0].full_name], [shape])
333+
shape_sliced = scope.get_unique_variable_name("shape_sliced")
334+
container.add_node("Slice", [shape, zero, one, zero], [shape_sliced])
335+
336+
# labels
337+
container.add_node(
338+
"ConstantOfShape",
339+
[shape_sliced],
340+
[operator.outputs[0].full_name],
341+
value=from_array(np.array([op.classes_[0]], dtype=np.int64)),
342+
)
343+
344+
# probabilities
345+
probas = scope.get_unique_variable_name("probas")
346+
container.add_node(
347+
"ConstantOfShape",
348+
[shape_sliced],
349+
[probas],
350+
value=from_array(np.array([1], dtype=dtype)),
351+
)
352+
container.add_node(
353+
"Reshape", [probas, new_shape], [operator.outputs[1].full_name]
354+
)
355+
return
356+
309357
if op.n_outputs_ == 1:
310358
attrs = get_default_tree_classifier_attribute_pairs()
311359
attrs["name"] = scope.get_unique_operator_name(op_type)
@@ -355,7 +403,7 @@ def convert_sklearn_decision_tree_classifier(
355403
[operator.outputs[0].full_name, operator.outputs[1].full_name],
356404
op_domain=op_domain,
357405
op_version=op_version,
358-
**attrs
406+
**attrs,
359407
)
360408

361409
n_out = 2
@@ -510,7 +558,7 @@ def convert_sklearn_decision_tree_regressor(
510558
operator.outputs[0].full_name,
511559
op_domain=op_domain,
512560
op_version=op_version,
513-
**attrs
561+
**attrs,
514562
)
515563

516564
options = scope.get_options(op, dict(decision_path=False, decision_leaf=False))
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import unittest
4+
import numpy as np
5+
6+
try:
7+
from onnx.reference import ReferenceEvaluator
8+
except ImportError:
9+
ReferenceEvaluator = None
10+
from sklearn.tree import DecisionTreeClassifier
11+
from onnxruntime import InferenceSession
12+
from skl2onnx import to_onnx
13+
from test_utils import TARGET_OPSET
14+
15+
16+
class TestSklearnClassifiersExtreme(unittest.TestCase):
    """Conversion checks for classifiers fitted on degenerate training data."""

    def test_one_training_class(self):
        """Convert a tree fitted on a single label and compare with scikit-learn.

        The converted model is executed with every available runtime
        (the ONNX reference evaluator when it can be imported, then
        onnxruntime) and both outputs, labels and probabilities, must be
        equal to what the scikit-learn estimator returns.
        """
        x = np.eye(4, dtype=np.float32)
        y = np.array([5, 5, 5, 5], dtype=np.int64)

        # ``fit`` hands back the estimator, so the calls can be chained.
        model = DecisionTreeClassifier().fit(x, y)

        expected = [model.predict(x), model.predict_proba(x)]
        onx = to_onnx(model, x, target_opset=TARGET_OPSET, options={"zipmap": False})

        def make_reference(proto):
            return ReferenceEvaluator(proto, verbose=0)

        def make_session(proto):
            return InferenceSession(
                proto.SerializeToString(), providers=["CPUExecutionProvider"]
            )

        # The reference evaluator is optional: it is None when the import
        # at the top of the file failed.
        runtimes = []
        if ReferenceEvaluator is not None:
            runtimes.append(make_reference)
        runtimes.append(make_session)

        for build in runtimes:
            sess = build(onx)
            outputs = sess.run(None, {"X": x})
            self.assertEqual(len(outputs), len(expected))
            for want, got in zip(expected, outputs):
                self.assertEqual(want.tolist(), got.tolist())
42+
43+
44+
if __name__ == "__main__":
45+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)