2
2
3
3
"""Tests scikit-learn's OrdinalEncoder converter."""
4
4
import unittest
5
+ from numpy .testing import assert_almost_equal
5
6
import packaging .version as pv
6
7
import numpy as np
7
8
import pandas as pd
8
9
import onnxruntime
9
10
from sklearn import __version__ as sklearn_version
11
+ from sklearn .compose import ColumnTransformer
12
+ from sklearn .pipeline import make_pipeline
13
+ from sklearn .ensemble import RandomForestRegressor
10
14
11
15
try :
12
16
from sklearn .preprocessing import OrdinalEncoder
13
17
except ImportError :
14
18
pass
15
- from skl2onnx import convert_sklearn
19
+ from skl2onnx import convert_sklearn , to_onnx
16
20
from skl2onnx .common .data_types import (
17
21
Int64TensorType ,
18
22
StringTensorType ,
@@ -30,6 +34,11 @@ def ordinal_encoder_support():
30
34
return pv .Version (vers ) >= pv .Version ("0.20.0" )
31
35
32
36
37
def set_output_support():
    """Return True when the installed scikit-learn is at least version 1.2.

    Only the first two dot-separated components of ``sklearn_version`` are
    kept before the comparison. The tests in this file use the result to
    skip cases that call ``ColumnTransformer.set_output``.
    """
    major_minor = sklearn_version.split(".")[:2]
    installed = pv.Version(".".join(major_minor))
    return installed >= pv.Version("1.2")
40
+
41
+
33
42
class TestSklearnOrdinalEncoderConverter (unittest .TestCase ):
34
43
@unittest .skipIf (
35
44
not ordinal_encoder_support (),
@@ -172,6 +181,89 @@ def test_model_ordinal_encoder_cat_list(self):
172
181
data , model , model_onnx , basename = "SklearnOrdinalEncoderCatList"
173
182
)
174
183
184
@unittest.skipIf(
    not set_output_support(),
    reason="'ColumnTransformer' object has no attribute 'set_output'",
)
@unittest.skipIf(
    not ordinal_encoder_support(),
    reason="OrdinalEncoder was not available before 0.20",
)
def test_ordinal_encoder_pipeline_int64(self):
    """Convert a pandas-output pipeline with an int64 OrdinalEncoder.

    The pipeline is a ColumnTransformer (one string column encoded by
    ``OrdinalEncoder(dtype=np.int64)``, one float32 column passed through,
    output set to pandas) followed by a small RandomForestRegressor. The
    ONNX model's predictions must match scikit-learn's.
    """
    from onnxruntime import InferenceSession

    data = pd.DataFrame({"cat": ["cat2", "cat1"], "num": [0, 1]})
    data["num"] = data["num"].astype(np.float32)
    y = np.array([0, 1], dtype=np.float32)
    preprocessor = ColumnTransformer(
        transformers=[
            ("cat", OrdinalEncoder(dtype=np.int64), ["cat"]),
            ("num", "passthrough", ["num"]),
        ],
        sparse_threshold=1,
        verbose_feature_names_out=False,
    ).set_output(transform="pandas")
    # A fixed seed keeps the fitted forest identical from run to run, so a
    # failure can be reproduced.
    model = make_pipeline(
        preprocessor,
        RandomForestRegressor(n_estimators=3, max_depth=2, random_state=0),
    )
    model.fit(data, y)
    expected = model.predict(data)

    # A single row is enough for to_onnx to infer the input names and types.
    model_onnx = to_onnx(model, data[:1], target_opset=TARGET_OPSET)
    sess = InferenceSession(
        model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    # Each dataframe column is fed as its own (n_rows, 1) input, keyed by
    # the column name.
    feeds = {
        "cat": data["cat"].values.reshape((-1, 1)),
        "num": data["num"].values.reshape((-1, 1)),
    }
    got = sess.run(None, feeds)
    # NOTE(review): the runtime output is presumably float32 while sklearn
    # predicts in float64, so compare at 5 decimals instead of the default 7.
    assert_almost_equal(expected, got[0].ravel(), decimal=5)
223
+
224
@unittest.skipIf(
    not set_output_support(),
    reason="'ColumnTransformer' object has no attribute 'set_output'",
)
@unittest.skipIf(
    not ordinal_encoder_support(),
    reason="OrdinalEncoder was not available before 0.20",
)
def test_ordinal_encoder_pipeline_string_int64(self):
    """Convert a pipeline whose OrdinalEncoder mixes string and int columns.

    ``OrdinalEncoder(dtype=np.int64)`` is applied to a string column ``C1``
    and an integer column ``C2``; a float32 column ``num`` is passed
    through. The ColumnTransformer outputs pandas and feeds a small
    RandomForestRegressor. The ONNX model's predictions must match
    scikit-learn's.
    """
    from onnxruntime import InferenceSession

    data = pd.DataFrame(
        {"C1": ["cat2", "cat1", "cat3"], "C2": [1, 0, 1], "num": [0, 1, 1]}
    )
    data["num"] = data["num"].astype(np.float32)
    y = np.array([0, 1, 2], dtype=np.float32)
    preprocessor = ColumnTransformer(
        transformers=[
            ("cat", OrdinalEncoder(dtype=np.int64), ["C1", "C2"]),
            ("num", "passthrough", ["num"]),
        ],
        sparse_threshold=1,
        verbose_feature_names_out=False,
    ).set_output(transform="pandas")
    # A fixed seed keeps the fitted forest identical from run to run, so a
    # failure can be reproduced.
    model = make_pipeline(
        preprocessor,
        RandomForestRegressor(n_estimators=3, max_depth=2, random_state=0),
    )
    model.fit(data, y)
    expected = model.predict(data)

    # A single row is enough for to_onnx to infer the input names and types.
    model_onnx = to_onnx(model, data[:1], target_opset=TARGET_OPSET)
    sess = InferenceSession(
        model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    # Each dataframe column is fed as its own (n_rows, 1) input, keyed by
    # the column name.
    feeds = {
        "C1": data["C1"].values.reshape((-1, 1)),
        "C2": data["C2"].values.reshape((-1, 1)),
        "num": data["num"].values.reshape((-1, 1)),
    }
    got = sess.run(None, feeds)
    # NOTE(review): the runtime output is presumably float32 while sklearn
    # predicts in float64, so compare at 5 decimals instead of the default 7.
    assert_almost_equal(expected, got[0].ravel(), decimal=5)
266
+
175
267
176
268
# Run the test suite with per-test output when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
0 commit comments