Skip to content

Commit 36e1dcd

Browse files
authored
add QuadraticDiscriminantAnalysis converter (#915)
* add QDA converter Signed-off-by: xiaowuhu <[email protected]> * fix flake8 Signed-off-by: xiaowuhu <[email protected]> * another flake8 Signed-off-by: xiaowuhu <[email protected]> * Update _supported_operators.py Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * Update quadratic_discriminant_analysis.py Signed-off-by: xiaowuhu <[email protected]> * Update quadratic_discriminant_analysis.py Signed-off-by: xiaowuhu <[email protected]> * Update quadratic_discriminant_analysis.py Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * add double dtype testing as output Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * add double output type test case Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * change file name to standard one Signed-off-by: xiaowuhu <[email protected]> * upgrade version to 1.13 Signed-off-by: xiaowuhu <[email protected]> Signed-off-by: xiaowuhu <[email protected]>
1 parent 9a86b5e commit 36e1dcd

7 files changed

+339
-2
lines changed

skl2onnx/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44
Main entry point to the converter from the *scikit-learn* to *onnx*.
55
"""
6-
__version__ = "1.12"
6+
__version__ = "1.13"
77
__author__ = "Microsoft"
88
__producer__ = "skl2onnx"
99
__producer_version__ = __version__

skl2onnx/_supported_operators.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@
6464
SGDOneClassSVM = None
6565

6666
from sklearn.svm import LinearSVR
67-
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
67+
from sklearn.discriminant_analysis import (
68+
LinearDiscriminantAnalysis,
69+
QuadraticDiscriminantAnalysis
70+
)
6871

6972
# Mixture
7073
from sklearn.mixture import (
@@ -288,6 +291,7 @@
288291
OneVsRestClassifier,
289292
PassiveAggressiveClassifier,
290293
Perceptron,
294+
QuadraticDiscriminantAnalysis,
291295
RandomForestClassifier,
292296
SGDClassifier,
293297
StackingClassifier,
@@ -383,6 +387,7 @@ def build_sklearn_operator_name_map():
383387
PoissonRegressor,
384388
PolynomialFeatures,
385389
PowerTransformer,
390+
QuadraticDiscriminantAnalysis,
386391
RadiusNeighborsClassifier,
387392
RadiusNeighborsRegressor,
388393
RandomForestClassifier,

skl2onnx/operator_converters/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from . import pipelines
4949
from . import polynomial_features
5050
from . import power_transformer
51+
from . import quadratic_discriminant_analysis
5152
from . import random_forest
5253
from . import random_projection
5354
from . import random_trees_embedding
@@ -112,6 +113,7 @@
112113
pipelines,
113114
polynomial_features,
114115
power_transformer,
116+
quadratic_discriminant_analysis,
115117
random_forest,
116118
random_projection,
117119
random_trees_embedding,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
from ..common._apply_operation import (
5+
apply_add, apply_argmax, apply_cast, apply_concat, apply_div, apply_exp,
6+
apply_log, apply_matmul, apply_mul, apply_pow,
7+
apply_reducesum, apply_reshape, apply_sub, apply_transpose)
8+
from ..common.data_types import (
9+
BooleanTensorType, Int64TensorType, guess_proto_type)
10+
from ..common._registration import register_converter
11+
from ..common._topology import Scope, Operator
12+
from ..common._container import ModelComponentContainer
13+
from ..proto import onnx_proto
14+
15+
16+
def convert_quadratic_discriminant_analysis_classifier(
        scope: Scope, operator: Operator, container: ModelComponentContainer):
    """Emit the ONNX graph for a fitted QuadraticDiscriminantAnalysis.

    For every class ``i`` the graph computes::

        norm2_i = sum(((X - means_[i]) @ (rotations_[i] * scalings_[i] ** -0.5)) ** 2, axis=1)
        u_i     = sum(log(scalings_[i]))

    and combines them into the decision function
    ``-0.5 * (norm2 + u) + log(priors_)``.  The predicted label is the class
    at the arg-max of the decision function (first output) and the
    probabilities are a max-shifted softmax of it (second output).

    Every initializer and intermediate tensor name is requested from
    ``scope`` so that several converted estimators can live in the same
    graph without their tensor names colliding.

    :param scope: naming scope used to generate unique variable names
    :param operator: operator wrapping the fitted scikit-learn model; its
        first input is the feature tensor, its outputs are label and
        probabilities
    :param container: container receiving the initializers and nodes
    """
    input_name = operator.inputs[0].full_name
    model = operator.raw_operator

    n_classes = len(model.classes_)

    # Only float and double computations are emitted; anything else is
    # computed in float.
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:
        proto_dtype = onnx_proto.TensorProto.FLOAT

    # Boolean / int64 inputs are cast to the computation type first.
    if isinstance(operator.inputs[0].type,
                  (BooleanTensorType, Int64TensorType)):
        cast_input_name = scope.get_unique_variable_name('cast_input')
        apply_cast(scope, operator.input_full_names, cast_input_name,
                   container, to=proto_dtype)
        input_name = cast_input_name

    norm_array_name = []
    sum_array_name = []

    # Scalar constants shared by all classes.
    const_n05_name = scope.get_unique_variable_name('const_n05')
    container.add_initializer(const_n05_name, proto_dtype, [], [-0.5])
    const_p2_name = scope.get_unique_variable_name('const_p2')
    container.add_initializer(const_p2_name, proto_dtype, [], [2])

    for i in range(n_classes):
        R = model.rotations_[i]
        rotation_name = scope.get_unique_variable_name('rotations')
        container.add_initializer(rotation_name, proto_dtype,
                                  [R.shape[0], R.shape[1]], R)

        S = model.scalings_[i]
        scaling_name = scope.get_unique_variable_name('scalings')
        container.add_initializer(
            scaling_name, proto_dtype, [S.shape[0], ], S)

        mean = model.means_[i]
        mean_name = scope.get_unique_variable_name('means')
        container.add_initializer(mean_name, proto_dtype, mean.shape, mean)

        # Xm = X - mean
        Xm_name = scope.get_unique_variable_name('Xm')
        apply_sub(scope, [input_name, mean_name], [Xm_name], container)

        # S ** -0.5
        s_pow_name = scope.get_unique_variable_name('s_pow_n05')
        apply_pow(scope, [scaling_name, const_n05_name], [s_pow_name],
                  container)

        # R * S ** -0.5
        mul_name = scope.get_unique_variable_name('mul')
        apply_mul(scope, [rotation_name, s_pow_name], [mul_name], container)

        # X2 = Xm @ (R * S ** -0.5)
        x2_name = scope.get_unique_variable_name('matmul')
        apply_matmul(scope, [Xm_name, mul_name], [x2_name], container)

        # sum(X2 ** 2, axis=1), kept as a column so it can be concatenated
        pow_x2_name = scope.get_unique_variable_name('pow_x2')
        apply_pow(scope, [x2_name, const_p2_name], [pow_x2_name], container)

        sum_name = scope.get_unique_variable_name('sum')
        apply_reducesum(scope, [pow_x2_name], [sum_name],
                        container, axes=[1], keepdims=1)
        norm_array_name.append(sum_name)

        # sum(log(S))
        log_name = scope.get_unique_variable_name('log')
        apply_log(scope, [scaling_name], [log_name], container)

        sum_log_name = scope.get_unique_variable_name('sum_log')
        apply_reducesum(
            scope, [log_name], [sum_log_name], container, keepdims=1)
        sum_array_name.append(sum_log_name)

    # Stack the per-class norms, then reshape + transpose so that each
    # column corresponds to one class.
    concat_norm_name = scope.get_unique_variable_name('concat_norm')
    apply_concat(scope, norm_array_name, [concat_norm_name], container)

    reshape_norm_name = scope.get_unique_variable_name('reshape_concat_norm')
    apply_reshape(scope, [concat_norm_name], [reshape_norm_name],
                  container, desired_shape=[n_classes, -1])

    transpose_norm_name = scope.get_unique_variable_name('transpose_norm')
    apply_transpose(scope, [reshape_norm_name], [transpose_norm_name],
                    container, perm=(1, 0))

    concat_logsum_name = scope.get_unique_variable_name('concat_logsum')
    apply_concat(scope, sum_array_name, [concat_logsum_name], container)

    # norm2 + u
    add_norm2_u_name = scope.get_unique_variable_name('add_norm2_u')
    apply_add(scope, [transpose_norm_name, concat_logsum_name],
              [add_norm2_u_name], container)

    # -0.5 * (norm2 + u)
    norm2_u_n05_name = scope.get_unique_variable_name('norm2_u_n05')
    apply_mul(
        scope, [const_n05_name, add_norm2_u_name], [norm2_u_n05_name],
        container)

    # + log(priors)
    priors_name = scope.get_unique_variable_name('priors')
    container.add_initializer(
        priors_name, proto_dtype, [n_classes, ], model.priors_)
    log_p_name = scope.get_unique_variable_name('log_p')
    apply_log(scope, [priors_name], [log_p_name], container)

    decision_fun_name = scope.get_unique_variable_name('decision_fun')
    apply_add(scope, [norm2_u_n05_name, log_p_name], [decision_fun_name],
              container)

    # Label output: class stored at the arg-max of the decision function.
    argmax_out_name = scope.get_unique_variable_name('argmax_out')
    apply_argmax(scope, [decision_fun_name], [argmax_out_name], container,
                 axis=1)

    # NOTE(review): classes are stored as INT64, so non-integer labels
    # (e.g. strings) are presumably not supported here -- confirm.
    classes_name = scope.get_unique_variable_name('classes')
    container.add_initializer(
        classes_name, onnx_proto.TensorProto.INT64, [n_classes],
        model.classes_)

    container.add_node(
        'ArrayFeatureExtractor',
        [classes_name, argmax_out_name],
        [operator.outputs[0].full_name],
        op_domain='ai.onnx.ml'
    )

    # Probability output: softmax of the decision function, shifted by the
    # row maximum before the exponential.
    # NOTE(review): ``axes`` is passed as a ReduceMax attribute; in the ONNX
    # specification it becomes an input from opset 18 -- confirm the opsets
    # this converter must support.
    df_max_name = scope.get_unique_variable_name('df_max')
    container.add_node(
        'ReduceMax', [decision_fun_name], [df_max_name], axes=[1])

    df_sub_max_name = scope.get_unique_variable_name('df_sub_max')
    apply_sub(scope, [decision_fun_name, df_max_name], [df_sub_max_name],
              container)

    likelihood_name = scope.get_unique_variable_name('likelihood')
    apply_exp(scope, [df_sub_max_name], [likelihood_name], container)

    likelihood_sum_name = scope.get_unique_variable_name('likelihood_sum')
    apply_reducesum(scope, [likelihood_name], [likelihood_sum_name],
                    container, axes=[1], keepdims=1)

    apply_div(scope, [likelihood_name, likelihood_sum_name],
              [operator.outputs[1].full_name], container)


register_converter('SklearnQuadraticDiscriminantAnalysis',
                   convert_quadratic_discriminant_analysis_classifier,
                   options={'zipmap': [True, False, 'columns'],
                            'nocl': [True, False],
                            'output_class_labels': [False, True]})

skl2onnx/shape_calculators/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from . import pipelines
3737
from . import polynomial_features
3838
from . import power_transformer
39+
from . import quadratic_discriminant_analysis
3940
from . import random_projection
4041
from . import random_trees_embedding
4142
from . import replace_op
@@ -84,6 +85,7 @@
8485
pipelines,
8586
polynomial_features,
8687
power_transformer,
88+
quadratic_discriminant_analysis,
8789
random_projection,
8890
random_trees_embedding,
8991
replace_op,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from ..common._registration import register_shape_calculator
4+
from ..common.data_types import Int64TensorType
5+
6+
7+
def calculate_quadratic_discriminant_analysis_shapes(operator):
    """Assign output types for a QuadraticDiscriminantAnalysis operator.

    The first output (labels) is typed as an int64 tensor of shape
    ``[1, n_classes]`` and the shape of the second output (probabilities)
    is set to ``[None, n_classes]``, where ``n_classes`` is the number of
    classes of the fitted model.

    :param operator: operator whose ``raw_operator`` is the fitted model
        and whose two outputs are updated in place
    """
    class_count = len(operator.raw_operator.classes_)
    outputs = operator.outputs
    # NOTE(review): the label shape uses the number of classes rather than
    # the number of samples -- confirm this is intended.
    outputs[0].type = Int64TensorType([1, class_count])
    outputs[1].type.shape = [None, class_count]


register_shape_calculator(
    'SklearnQuadraticDiscriminantAnalysis',
    calculate_quadratic_discriminant_analysis_shapes)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
"""Tests scikit-learn's QuadraticDiscriminantAnalysis converter."""
4+
5+
import sklearn
6+
import unittest
7+
import numpy as np
8+
import packaging.version as pv
9+
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
10+
from onnxruntime import __version__ as ort_version
11+
from onnx import __version__ as onnx_version
12+
from skl2onnx import convert_sklearn
13+
from skl2onnx.common.data_types import (
14+
FloatTensorType,
15+
DoubleTensorType
16+
)
17+
18+
from test_utils import (
19+
dump_data_and_model,
20+
TARGET_OPSET
21+
)
22+
23+
ort_version = ".".join(ort_version.split(".")[:2])
24+
onnx_version = ".".join(onnx_version.split('.')[:2])
25+
26+
27+
class TestQuadraticDiscriminantAnalysisConverter(unittest.TestCase):
    """Conversion tests for QuadraticDiscriminantAnalysis.

    Each test fits a small model, converts it with ``convert_sklearn`` and
    hands model, ONNX graph and test data to ``dump_data_and_model``.
    Float and double inputs are covered for 2 or 3 classes and 2 or 3
    features.
    """

    def _fit_convert_dump(self, X, y, X_test, input_type, basename,
                          **convert_kwargs):
        # Shared body of all tests: fit, convert, check, dump.
        skl_model = QuadraticDiscriminantAnalysis()
        skl_model.fit(X, y)

        initial_types = [("input", input_type([None, X.shape[1]]))]
        onnx_model = convert_sklearn(
            skl_model,
            "scikit-learn QDA",
            initial_types,
            target_opset=TARGET_OPSET,
            **convert_kwargs)

        self.assertIsNotNone(onnx_model)
        dump_data_and_model(X_test, skl_model, onnx_model,
                            basename=basename)

    @unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'),
                     reason="scikit-learn<1.0")
    @unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'),
                     reason="fails with onnx 1.10")
    def test_model_qda_2c2f_float(self):
        # 2 classes, 2 features
        train_X = np.array(
            [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
        train_y = np.array([1, 1, 1, 2, 2, 2])
        eval_X = np.array([[-0.8, -1], [0.8, 1]])

        self._fit_convert_dump(
            train_X, train_y, eval_X.astype(np.float32),
            FloatTensorType, "SklearnQDA_2c2f_Float")

    @unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'),
                     reason="scikit-learn<1.0")
    @unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'),
                     reason="fails with onnx 1.10")
    def test_model_qda_2c3f_float(self):
        # 2 classes, 3 features
        train_X = np.array([[-1, -1, 0], [-2, -1, 1], [-3, -2, 0],
                            [1, 1, 0], [2, 1, 1], [3, 2, 1]])
        train_y = np.array([1, 1, 1, 2, 2, 2])
        eval_X = np.array([[-0.8, -1, 0], [-1, -1.6, 0],
                           [1, 1.5, 1], [3.1, 2.1, 1]])

        self._fit_convert_dump(
            train_X, train_y, eval_X.astype(np.float32),
            FloatTensorType, "SklearnQDA_2c3f_Float")

    @unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'),
                     reason="scikit-learn<1.0")
    @unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'),
                     reason="fails with onnx 1.10")
    def test_model_qda_3c2f_float(self):
        # 3 classes, 2 features
        train_X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1],
                            [2, 1], [3, 2], [-1, 2], [-2, 3], [-2, 2]])
        train_y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3])
        eval_X = np.array([[-0.8, -1], [0.8, 1], [-0.8, 1]])

        self._fit_convert_dump(
            train_X, train_y, eval_X.astype(np.float32),
            FloatTensorType, "SklearnQDA_3c2f_Float")

    @unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'),
                     reason="scikit-learn<1.0")
    @unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'),
                     reason="fails with onnx 1.10")
    def test_model_qda_2c2f_double(self):
        # 2 classes, 2 features
        train_X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1],
                            [2, 1], [3, 2]]).astype(np.double)
        train_y = np.array([1, 1, 1, 2, 2, 2])
        eval_X = np.array([[-0.8, -1], [0.8, 1]])

        self._fit_convert_dump(
            train_X, train_y, eval_X.astype(np.double),
            DoubleTensorType, "SklearnQDA_2c2f_Double",
            options={'zipmap': False})

    @unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'),
                     reason="scikit-learn<1.0")
    @unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'),
                     reason="fails with onnx 1.10")
    def test_model_qda_2c3f_double(self):
        # 2 classes, 3 features
        train_X = np.array([[-1, -1, 0], [-2, -1, 1], [-3, -2, 0],
                            [1, 1, 0], [2, 1, 1], [3, 2, 1]]
                           ).astype(np.double)
        train_y = np.array([1, 1, 1, 2, 2, 2])
        eval_X = np.array([[-0.8, -1, 0], [-1, -1.6, 0],
                           [1, 1.5, 1], [3.1, 2.1, 1]])

        self._fit_convert_dump(
            train_X, train_y, eval_X.astype(np.double),
            DoubleTensorType, "SklearnQDA_2c3f_Double",
            options={'zipmap': False})

    @unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'),
                     reason="scikit-learn<1.0")
    @unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'),
                     reason="fails with onnx 1.10")
    def test_model_qda_3c2f_double(self):
        # 3 classes, 2 features
        train_X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1],
                            [3, 2], [-1, 2], [-2, 3], [-2, 2]]
                           ).astype(np.double)
        train_y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3])
        eval_X = np.array([[-0.8, -1], [0.8, 1], [-0.8, 1]])

        self._fit_convert_dump(
            train_X, train_y, eval_X.astype(np.double),
            DoubleTensorType, "SklearnQDA_3c2f_Double",
            options={'zipmap': False})


if __name__ == "__main__":
    unittest.main(verbosity=3)

0 commit comments

Comments
 (0)