Skip to content

Commit 2260310

Browse files
xiaowuhu and xadupre authored
Xiaowu/one vs one classifier (#904)
* add essential files Signed-off-by: xiaowuhu <[email protected]> * update files Signed-off-by: xiaowuhu <[email protected]> * updated Signed-off-by: xiaowuhu <[email protected]> * make the converter working Signed-off-by: xiaowuhu <[email protected]> * Update test_sklearn_one_vs_one_classifier_converter.py Signed-off-by: xiaowuhu <[email protected]> * Update one_vs_one_classifier.py Signed-off-by: xiaowuhu <[email protected]> * update files Signed-off-by: xiaowuhu <[email protected]> * first fix for ovo converter Signed-off-by: xadupre <[email protected]> * fix ovo, still an issue with LogisticRegression and DecisionTree Signed-off-by: xadupre <[email protected]> * Update requirements.txt Signed-off-by: xiaowuhu <[email protected]> * Update requirements.txt Signed-off-by: xiaowuhu <[email protected]> * remove unnecessary files Signed-off-by: xiaowuhu <[email protected]> * fix ovo Signed-off-by: xadupre <[email protected]> * final fix for ovo Signed-off-by: xadupre <[email protected]> * remove unnecessary option Signed-off-by: xadupre <[email protected]> * lint issues Signed-off-by: xadupre <[email protected]> * update ci Signed-off-by: xadupre <[email protected]> * Update linux-conda-CI.yml Signed-off-by: xiaowuhu <[email protected]> * change CI Signed-off-by: xiaowuhu <[email protected]> Signed-off-by: xiaowuhu <[email protected]> Signed-off-by: xadupre <[email protected]> Co-authored-by: xadupre <[email protected]>
1 parent 9ece520 commit 2260310

13 files changed

+407
-115
lines changed

.azure-pipelines/linux-conda-CI.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,10 @@ jobs:
268268
displayName: 'pytest'
269269
270270
- script: |
271+
# some of this is triggering the following error when importing scipy on python 3.10
272+
# ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29'
271273
conda install -c conda-forge "lightgbm${lgbm.version}" xgboost --no-deps
272-
pip install xgboost lightgbm hummingbird-ml hummingbird
274+
pip install hummingbird-ml hummingbird xgboost lightgbm
273275
pip install --no-deps git+https://github.com/microsoft/onnxconverter-common.git
274276
pip install onnxmltools
275277
displayName: 'install onnxmltools'

.azure-pipelines/win32-conda-CI.yml

+1-18
Original file line numberDiff line numberDiff line change
@@ -121,24 +121,7 @@ jobs:
121121
onnxrt.version: 'onnxruntime==1.7.0' # -i https://test.pypi.org/simple/ ort-nightly'
122122
onnxcc.version: 'onnxconverter-common==1.7.0' # git+https://github.com/microsoft/onnxconverter-common.git
123123
sklearn.version: '==0.24.1'
124-
Py38-Onnx181-Rt160-Skl0240:
125-
python.version: '3.8'
126-
onnx.version: 'onnx==1.8.1'
127-
onnx.target_opset: ''
128-
numpy.version: 'numpy>=1.18.1'
129-
scipy.version: 'scipy'
130-
onnxrt.version: 'onnxruntime==1.6.0'
131-
onnxcc.version: 'onnxconverter-common==1.7.0'
132-
sklearn.version: '==0.24.0'
133-
Py38-Onnx170-Rt160-Skl0240:
134-
python.version: '3.8'
135-
onnx.version: 'onnx==1.7.0'
136-
onnx.target_opset: ''
137-
numpy.version: 'numpy>=1.18.1'
138-
scipy.version: 'scipy'
139-
onnxrt.version: 'onnxruntime==1.6.0'
140-
onnxcc.version: 'onnxconverter-common==1.7.0'
141-
sklearn.version: '==0.24.0'
124+
142125

143126
maxParallel: 3
144127

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ onnx>=1.2.1
55
scikit-learn>=0.19
66
scikit-learn<=1.1.1
77
onnxconverter-common>=1.7.0
8+
scikit-learn<=1.1.1

skl2onnx/_supported_operators.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
)
7373

7474
# Multi-class
75-
from sklearn.multiclass import OneVsRestClassifier
75+
from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier
7676

7777
# Tree-based models
7878
from sklearn.ensemble import (
@@ -284,6 +284,7 @@
284284
MLPClassifier,
285285
MultinomialNB,
286286
NuSVC,
287+
OneVsOneClassifier,
287288
OneVsRestClassifier,
288289
PassiveAggressiveClassifier,
289290
Perceptron,
@@ -373,6 +374,7 @@ def build_sklearn_operator_name_map():
373374
Normalizer,
374375
OneClassSVM,
375376
OneHotEncoder,
377+
OneVsOneClassifier,
376378
OneVsRestClassifier,
377379
OrdinalEncoder,
378380
PCA,

skl2onnx/operator_converters/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from . import nearest_neighbours
4242
from . import normaliser
4343
from . import one_hot_encoder
44+
from . import one_vs_one_classifier
4445
from . import one_vs_rest_classifier
4546
from . import ordinal_encoder
4647
from . import ovr_decision_function
@@ -104,6 +105,7 @@
104105
nearest_neighbours,
105106
normaliser,
106107
one_hot_encoder,
108+
one_vs_one_classifier,
107109
one_vs_rest_classifier,
108110
ordinal_encoder,
109111
ovr_decision_function,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from sklearn.base import is_regressor
4+
from ..proto import onnx_proto
5+
from ..common._registration import register_converter
6+
from ..common._topology import Scope, Operator
7+
from ..common._container import ModelComponentContainer
8+
from ..common._apply_operation import apply_cast, apply_concat, apply_reshape
9+
from ..common.data_types import guess_proto_type, Int64TensorType
10+
from .._supported_operators import sklearn_operator_name_map
11+
12+
13+
def _iteration_one_versus(scope, container, inputs, i, estimator, cl_type,
                          proto_dtype, use_raw_scores=True, prob_shape=None):
    """
    Declares the sub-operator converting one fitted estimator and adds
    the nodes extracting its label and the second column of its scores.

    :param scope: scope used to declare operators and variables and to
        generate unique names
    :param container: container receiving the initializers and the nodes
    :param inputs: input variables assigned to the sub-operator
    :param i: index of the estimator, only used to build names
    :param estimator: fitted estimator to convert
    :param cl_type: tensor type class instantiated for a regressor score
    :param proto_dtype: ONNX element type the label is cast to
    :param use_raw_scores: value of option *raw_scores* when the
        estimator declares that option
    :param prob_shape: name of the tensor holding the shape used to
        reshape the scores, it is built from the first input if None
    :return: *(prob_shape, cast label name, sliced score name)*,
        or *(None, None, score name)* for a regressor
    :raises RuntimeError: if a regressor has a 2D attribute ``coef_``
    """
    sub_operator = scope.declare_local_operator(
        sklearn_operator_name_map[type(estimator)], raw_model=estimator)
    sub_operator.inputs = inputs

    if is_regressor(estimator):
        score_var = scope.declare_local_variable('score_%d' % i, cl_type())
        sub_operator.outputs.append(score_var)

        if hasattr(estimator, 'coef_') and len(estimator.coef_.shape) == 2:
            raise RuntimeError(
                "OneVsRestClassifier or OneVsOneClassifier accepts "
                "regressor with only one target.")
        return None, None, score_var.onnx_name

    # Options forwarded to the sub-estimator: raw scores if supported,
    # otherwise the zipmap is disabled to get a tensor.
    sub_options = None
    if container.has_options(estimator, 'raw_scores'):
        sub_options = {'raw_scores': use_raw_scores}
    elif container.has_options(estimator, 'zipmap'):
        sub_options = {'zipmap': False}
    if sub_options is not None:
        for holder in (container, scope):
            holder.add_options(id(estimator), sub_options)

    label_var = scope.declare_local_variable(
        'label_%d' % i, Int64TensorType())
    proba_var = scope.declare_local_variable(
        'proba_%d' % i, inputs[0].type.__class__())
    sub_operator.outputs.append(label_var)
    sub_operator.outputs.append(proba_var)

    # label: reshaped into one column then cast to proto_dtype
    label_column = scope.get_unique_variable_name('lab_%d' % i)
    apply_reshape(scope, label_var.onnx_name, label_column, container,
                  desired_shape=(-1, 1))
    cast_label = scope.get_unique_variable_name('cast_lab_%d' % i)
    apply_cast(scope, label_column, cast_label, container, to=proto_dtype)

    if prob_shape is None:
        # target shape: (first dimension of the input, -1)
        zero = scope.get_unique_variable_name('cst0')
        container.add_initializer(
            zero, onnx_proto.TensorProto.INT64, [1], [0])
        input_shape = scope.get_unique_variable_name('shape')
        container.add_node('Shape', [inputs[0].full_name], [input_shape])
        n_rows = scope.get_unique_variable_name('dim')
        container.add_node('Gather', [input_shape, zero], [n_rows])
        minus_one = scope.get_unique_variable_name('cst_1')
        container.add_initializer(
            minus_one, onnx_proto.TensorProto.INT64, [1], [-1])
        prob_shape = scope.get_unique_variable_name('shape')
        apply_concat(scope, [n_rows, minus_one], prob_shape, container,
                     axis=0)

    reshaped = scope.get_unique_variable_name('prob_%d' % i)
    container.add_node(
        'Reshape', [proba_var.onnx_name, prob_shape], [reshaped])

    one = scope.get_unique_variable_name('cst1')
    container.add_initializer(one, onnx_proto.TensorProto.INT64, [1], [1])
    two = scope.get_unique_variable_name('cst2')
    container.add_initializer(two, onnx_proto.TensorProto.INT64, [1], [2])

    # Slice inputs are (data, starts, ends, axes): keeps [1:2] on axis 1
    prob1 = scope.get_unique_variable_name('prob1_%d' % i)
    container.add_node('Slice', [reshaped, one, two, one], prob1)
    return prob_shape, cast_label, prob1
84+
85+
86+
def convert_one_vs_one_classifier(scope: Scope, operator: Operator,
                                  container: ModelComponentContainer):
    """
    Converts a fitted *OneVsOneClassifier* into ONNX.

    Every estimator in ``estimators_`` is handled by
    *_iteration_one_versus*. The label and score names it returns are
    concatenated along axis 1 and given to a local operator
    *SklearnOVRDecisionFunction*. The output of that operator is copied
    into the second output of *operator* and its *ArgMax* along axis 1
    is written into the first output.

    :param scope: scope used to declare operators and variables and to
        generate unique names
    :param operator: operator to convert, ``operator.raw_operator`` is
        the fitted model
    :param container: container receiving the initializers and the nodes
    """
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:
        proto_dtype = onnx_proto.TensorProto.FLOAT
    op = operator.raw_operator
    cl_type = operator.inputs[0].type.__class__

    # The tensor holding the shape used to reshape the scores is built by
    # the first call to _iteration_one_versus and reused by the following
    # ones. Building another one here would only leave unused
    # initializers and nodes in the graph.
    prob_shape = None
    label_names = []
    prob_names = []
    for i, estimator in enumerate(op.estimators_):
        prob_shape, cast_label, prob1 = _iteration_one_versus(
            scope, container, operator.inputs, i, estimator, cl_type,
            proto_dtype, True, prob_shape=prob_shape)

        label_names.append(cast_label)
        prob_names.append(prob1)

    conc_lab_name = scope.get_unique_variable_name('concat_out_ovo_label')
    apply_concat(scope, label_names, conc_lab_name, container, axis=1)
    conc_prob_name = scope.get_unique_variable_name('concat_out_ovo_prob')
    apply_concat(scope, prob_names, conc_prob_name, container, axis=1)

    # calls _ovr_decision_function
    this_operator = scope.declare_local_operator(
        "SklearnOVRDecisionFunction", op)

    # Identity nodes link the concatenated results to the variables
    # declared as inputs of the local operator.
    label = scope.declare_local_variable("label", cl_type())
    container.add_node('Identity', [conc_lab_name], [label.onnx_name])
    prob_score = scope.declare_local_variable("prob_score", cl_type())
    container.add_node('Identity', [conc_prob_name], [prob_score.onnx_name])

    this_operator.inputs.append(label)
    this_operator.inputs.append(prob_score)

    ovr_name = scope.declare_local_variable('ovr_output', cl_type())
    this_operator.outputs.append(ovr_name)

    output_name = operator.outputs[1].full_name
    container.add_node('Identity', [ovr_name.onnx_name], [output_name])

    # ArgMax must read the name given to the variable by the scope:
    # it differs from 'ovr_output' whenever that name is already taken.
    container.add_node(
        'ArgMax', ovr_name.onnx_name, operator.outputs[0].full_name, axis=1)
148+
149+
150+
# Registers the converter with the options it accepts and their
# allowed values.
register_converter(
    'SklearnOneVsOneClassifier',
    convert_one_vs_one_classifier,
    options=dict(
        zipmap=[True, False, 'columns'],
        nocl=[True, False],
        output_class_labels=[False, True],
    ))

skl2onnx/operator_converters/one_vs_rest_classifier.py

+45-34
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from sklearn.base import is_regressor
4+
from sklearn.svm import LinearSVC
45
from ..proto import onnx_proto
56
from ..common._apply_operation import (
67
apply_concat, apply_identity, apply_mul, apply_reshape)
@@ -15,6 +16,45 @@
1516
from .._supported_operators import sklearn_operator_name_map
1617

1718

19+
def _iteration_one_versus(scope, container, inputs, i, estimator, cl_type,
                          proto_dtype, use_raw_scores=True, prob_shape=None):
    """
    Declares the sub-operator converting one fitted estimator and
    returns the name of the result holding its score.

    :param scope: scope used to declare operators and variables and to
        generate unique names
    :param container: container receiving the options and the nodes
    :param inputs: input variables assigned to the sub-operator
    :param i: index of the estimator, only used to build names
    :param estimator: fitted estimator to convert
    :param cl_type: tensor type class instantiated for the score or
        the probabilities
    :param proto_dtype: not used by this function
    :param use_raw_scores: value of option *raw_scores* when the
        estimator declares that option
    :param prob_shape: not used by this function
    :return: *(None, None, name)*, *name* is the regressor score,
        the whole output for a *LinearSVC*, otherwise the slice
        ``[1:2]`` on axis 1 of the second output
    :raises RuntimeError: if a regressor has a 2D attribute ``coef_``
    """
    sub_operator = scope.declare_local_operator(
        sklearn_operator_name_map[type(estimator)], raw_model=estimator)
    sub_operator.inputs = inputs

    if is_regressor(estimator):
        score_var = scope.declare_local_variable('score_%d' % i, cl_type())
        sub_operator.outputs.append(score_var)

        if hasattr(estimator, 'coef_') and len(estimator.coef_.shape) == 2:
            raise RuntimeError(
                "OneVsRestClassifier or OneVsOneClassifier accepts "
                "regressor with only one target.")
        return None, None, score_var.onnx_name

    if container.has_options(estimator, 'raw_scores'):
        # the same option is registered in the container and the scope
        for holder in (container, scope):
            holder.add_options(
                id(estimator), {'raw_scores': use_raw_scores})

    label_var = scope.declare_local_variable(
        'label_%d' % i, Int64TensorType())
    proba_var = scope.declare_local_variable('proba_%d' % i, cl_type())
    sub_operator.outputs.append(label_var)
    sub_operator.outputs.append(proba_var)

    # gets the probability for the class 1
    class1_name = scope.get_unique_variable_name('probY_%d' % i)
    if not isinstance(estimator, LinearSVC):
        apply_slice(scope, proba_var.onnx_name, class1_name, container,
                    starts=[1], ends=[2], axes=[1])
    else:
        # the output of a LinearSVC is forwarded without any slicing
        apply_identity(scope, proba_var.onnx_name, class1_name, container)
    return None, None, class1_name
56+
57+
1858
def convert_one_vs_rest_classifier(scope: Scope, operator: Operator,
1959
container: ModelComponentContainer):
2060
"""
@@ -31,41 +71,12 @@ def convert_one_vs_rest_classifier(scope: Scope, operator: Operator,
3171
options = container.get_options(op, dict(raw_scores=False))
3272
use_raw_scores = options['raw_scores']
3373
probs_names = []
74+
cl_type = operator.inputs[0].type.__class__
75+
prob_shape = None
3476
for i, estimator in enumerate(op.estimators_):
35-
op_type = sklearn_operator_name_map[type(estimator)]
36-
37-
this_operator = scope.declare_local_operator(
38-
op_type, raw_model=estimator)
39-
this_operator.inputs = operator.inputs
40-
41-
if is_regressor(estimator):
42-
score_name = scope.declare_local_variable(
43-
'score_%d' % i, operator.inputs[0].type.__class__())
44-
this_operator.outputs.append(score_name)
45-
46-
if hasattr(estimator, 'coef_') and len(estimator.coef_.shape) == 2:
47-
raise RuntimeError("OneVsRestClassifier accepts "
48-
"regressor with only one target.")
49-
p1 = score_name.onnx_name
50-
else:
51-
if container.has_options(estimator, 'raw_scores'):
52-
container.add_options(
53-
id(estimator), {'raw_scores': use_raw_scores})
54-
scope.add_options(
55-
id(estimator), {'raw_scores': use_raw_scores})
56-
label_name = scope.declare_local_variable(
57-
'label_%d' % i, Int64TensorType())
58-
prob_name = scope.declare_local_variable(
59-
'proba_%d' % i, operator.inputs[0].type.__class__())
60-
this_operator.outputs.append(label_name)
61-
this_operator.outputs.append(prob_name)
62-
63-
# gets the probability for the class 1
64-
p1 = scope.get_unique_variable_name('probY_%d' % i)
65-
apply_slice(scope, prob_name.onnx_name, p1, container, starts=[1],
66-
ends=[2], axes=[1],
67-
operator_name=scope.get_unique_operator_name('Slice'))
68-
77+
prob_shape, _, p1 = _iteration_one_versus(
78+
scope, container, operator.inputs, i, estimator, cl_type,
79+
proto_dtype, use_raw_scores, prob_shape=prob_shape)
6980
probs_names.append(p1)
7081

7182
if op.multilabel_:

0 commit comments

Comments
 (0)