Skip to content

Commit 77e3b98

Browse files
CloudManXUbuntusrinidhigoud
authored and
Trevor Morris
committed
Sklearn frontend updates (#157)
* Add different path for multiple transform functions of different models and InverseLabelTransform support * typo fixes and copy the value in identity transformation * add shapefunc copy and inverse transform of NALabelEncoder * typo fix * add import tvm.testing * add more operator support * adding test cases - test * fix shape mismatch for NALabelEncoder when using dynamic shapes * update and merging for more operator support in sklearn frontend * update and merging for more operator support in sklearn frontend: robustordinalencoder * remove debug prints * reformat with black * pylint format fixes Co-authored-by: Ubuntu <[email protected]> Co-authored-by: srinidhigoud <[email protected]>
1 parent e41640d commit 77e3b98

File tree

3 files changed

+330
-43
lines changed

3 files changed

+330
-43
lines changed

python/tvm/relay/frontend/sklearn.py

+228-28
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# pylint: disable=import-outside-toplevel
1919

2020
import numpy as np
21-
import tvm
21+
from tvm import relay
2222
from tvm.ir import IRModule
2323

2424
from ... import nd as _nd
@@ -107,68 +107,268 @@ def _ThresholdOneHotEncoder(op, inexpr, dshape, dtype, columns=None):
107107

108108
def _RobustStandardScaler(op, inexpr, dshape, dtype, columns=None):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Standardize features by removing the mean and scaling to unit variance.

    Emits ``(inexpr - mean_) / scale_`` using the statistics stored on ``op.scaler_``.
    ``dshape`` and ``columns`` are accepted for signature compatibility and are not
    used by this definition.
    """
    fitted = op.scaler_
    mean = _op.const(np.array(fitted.mean_, dtype), dtype)
    scale = _op.const(np.array(fitted.scale_, dtype), dtype)
    centered = _op.subtract(inexpr, mean)
    return _op.divide(centered, scale)
117117

118-
def _ColumnTransformer(op, inexpr, dshape, dtype, func_name, columns=None):
    """
    Scikit-Learn Compose:
    Applies transformers to columns of an array.

    Every entry of ``op.transformers_`` is unpacked as ``(name, pipeline, columns)``.
    Only the first step of each pipeline is converted (``pipeline.steps[0][1]``), with
    that entry's column selection, and the per-transformer results are concatenated
    along axis 1. The ``columns`` argument of this function itself is not used.
    """
    converted = [
        sklearn_op_to_relay(pipeline.steps[0][1], inexpr, dshape, dtype, func_name, selected)
        for _, pipeline, selected in op.transformers_
    ]
    return _op.concatenate(converted, axis=1)
129130

131+
132+
def _InverseLabelTransformer(op, inexpr, dshape, dtype, columns=None):
    """
    Identity transformation of the label data.

    The emitted graph is a plain copy of ``inexpr``; the conversion to string happens
    in runtime. ``op``, ``dshape``, ``dtype`` and ``columns`` are unused.
    """
    passthrough = _op.copy(inexpr)
    return passthrough
137+
138+
139+
def _RobustOrdinalEncoder(op, inexpr, dshape, dtype, columns=None):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Encode categorical features as an integer array, with the additional feature of
    handling unseen values.

    Each feature is mapped to the index of its value inside the matching entry of
    ``op.categories_`` (0 to n_categories - 1). A value that matches none of the known
    categories is mapped to ``n_categories``. The categories are materialised with
    ``np.array(category, dtype=dtype)``, so they must be representable in ``dtype``.
    """
    if columns:
        inexpr = _op.take(inexpr, indices=_op.const(columns), axis=1)

    feature_count = len(op.categories_)
    # One section per entry of categories_ (assumes one input column per entry).
    feature_cols = _op.split(inexpr, feature_count, axis=1)

    encoded = []
    for idx, known in enumerate(op.categories_):
        width = len(known)
        known_values = _op.const(np.array(known, dtype=dtype))

        # Repeat the feature column `width` times so it can be compared against
        # every known category in a single elementwise op.
        repeated = _op.tile(feature_cols[idx], (1, width))
        matches = _op.equal(repeated, known_values)
        hits = _op.cast(matches, dtype)

        # hits + [-1, 0, ..., width - 2] equals the category index wherever a match
        # occurred; everything else is zeroed so the row sum yields the ordinal.
        shift = _op.const(np.arange(-1, width - 1, dtype=dtype))
        blanks = _op.full_like(hits, _op.const(0, dtype=dtype))
        indexed = _op.where(matches, _op.add(hits, shift), blanks)
        ordinal = _op.expand_dims(_op.sum(indexed, axis=1), -1)

        # Rows without any match are assigned the extra class `width`.
        was_seen = _op.expand_dims(_op.cast(_op.sum(hits, axis=1), dtype="bool"), -1)
        unseen_class = _op.full_like(ordinal, _op.const(width, dtype=dtype))
        encoded.append(_op.where(was_seen, ordinal, unseen_class))

    return _op.concatenate(encoded, axis=1)
175+
176+
177+
def _RobustLabelEncoder(op, inexpr, dshape, dtype, columns=None, is_inverse=False):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Encode target labels with value between 0 and n_classes-1.

    Parameters
    ----------
    op : fitted encoder
        Must expose ``classes_`` and ``fill_unseen_labels``. The classes are
        materialised with ``np.array(cls, dtype)``, so they must be representable
        in ``dtype``.
    inexpr : relay expression
        The labels to encode (or the encoded labels when ``is_inverse`` is set).
    dshape : unused
    dtype : str
        Data type of every constant that is emitted.
    columns : list of int, optional
        When given, only these columns (axis 1) of ``inexpr`` are processed.
    is_inverse : bool, optional
        ``False`` (default) maps ``classes_[i] -> i``; ``True`` maps
        ``i -> classes_[i]``. The body previously read this name without it being
        defined anywhere in the function, which raised ``NameError`` on every call.

    Returns
    -------
    relay expression with the same shape as the (column-selected) input. Inputs that
    match no class are left untouched unless ``op.fill_unseen_labels`` is truthy, in
    which case they become ``len(classes_)`` (forward) or ``-1`` (inverse).
    """
    if columns:
        column_indices = _op.const(columns)
        inexpr = _op.take(inexpr, indices=column_indices, axis=1)

    num_classes = len(op.classes_)
    encoded = [_op.const(i, dtype) for i in range(num_classes)]
    raw = [_op.const(np.array(cls, dtype), dtype) for cls in op.classes_]
    # The direction only decides which side is matched and which side is written.
    sources, targets = (encoded, raw) if is_inverse else (raw, encoded)

    # All masks are computed against the untouched input so that a value written
    # for one class can never be re-matched by a later class.
    class_mask = [_op.equal(inexpr, source) for source in sources]

    out = inexpr
    for mask, target in zip(class_mask, targets):
        out = _op.where(mask, _op.full_like(inexpr, target), out)

    if op.fill_unseen_labels:
        unseen_label = (
            _op.const(-1, dtype=dtype)
            if is_inverse
            else _op.const(np.array(num_classes), dtype=dtype)
        )
        label_mask = _op.full_like(inexpr, unseen_label)
        if class_mask:
            seen_mask = class_mask[0]
            for mask in class_mask[1:]:
                seen_mask = _op.logical_or(seen_mask, mask)
            out = _op.where(_op.logical_not(seen_mask), label_mask, out)
        else:
            # Without any known class every input is unseen.
            out = label_mask

    return out
219+
220+
221+
def _NALabelEncoder(op, inexpr, dshape, dtype, columns=None):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Encoder for transforming labels to NA values which encode all non-float and
    non-finite values as NA values.

    The labels are reshaped into a single column, passed through the converter of
    ``_RobustImputer`` using ``op.model_``, and flattened again to one dimension.
    The incoming ``dshape`` is not used.
    """
    if columns:
        inexpr = _op.take(inexpr, indices=_op.const(columns), axis=1)

    # The imputer is always handed a (?, 1) shape: the row count is left dynamic.
    column_shape = (relay.Any(), 1)
    as_column = _op.reshape(inexpr, newshape=(-1, 1))
    imputed = _RobustImputer(op.model_, as_column, column_shape, dtype)
    return _op.reshape(imputed, newshape=-1)
237+
238+
239+
def _RobustStandardScaler(op, inexpr, dshape, dtype, columns=None):
    """
    Sagemaker-Scikit-Learn-Extension Transformer:
    Standardize features by removing the mean and scaling to unit variance.

    When ``columns`` is given, only those columns (axis 1) are taken from ``inexpr``
    before ``(x - mean_) / scale_`` is emitted with the statistics of ``op.scaler_``.
    ``dshape`` is not used.
    """
    if columns:
        inexpr = _op.take(inexpr, indices=_op.const(columns), axis=1)

    fitted = op.scaler_
    mean = _op.const(np.array(fitted.mean_, dtype), dtype)
    scale = _op.const(np.array(fitted.scale_, dtype), dtype)
    centered = _op.subtract(inexpr, mean)
    return _op.divide(centered, scale)
252+
253+
254+
def _KBinsDiscretizer(op, inexpr, dshape, dtype, columns=None):
    """
    Scikit-Learn Transformer:
    Bin continuous data into intervals.

    Emits, per element, the index of the bin the value falls into. Values below the
    second edge get bin 0 and values at or above the second-to-last edge get the last
    bin, so out-of-range inputs are clamped to the outermost bins.

    ``op.bin_edges_`` is stacked with ``np.vstack``, so every feature must have the
    same number of edges. ``dshape`` is not used.
    """
    if columns:
        column_indices = _op.const(columns)
        inexpr = _op.take(inexpr, indices=column_indices, axis=1)

    # After the transpose, row i holds the i-th edge of every feature so it can be
    # compared against all columns of the input at once.
    bin_edges = np.transpose(np.vstack(op.bin_edges_))
    out = _op.full_like(inexpr, _op.const(0, dtype=dtype))

    # The first and last edges are skipped: they only bound the outermost bins.
    for i in range(1, len(bin_edges) - 1):
        bin_index = _op.full_like(inexpr, _op.const(i, dtype=dtype))
        # Build the edge in the requested dtype, like every other constant here, so
        # the comparison operands agree for any `dtype` and not only the default.
        bin_edge = _op.const(np.array(bin_edges[i], dtype), dtype)
        reached = _op.greater_equal(inexpr, bin_edge)
        out = _op.where(reached, bin_index, out)

    return out
273+
274+
275+
def _TfidfVectorizer(op, inexpr, dshape, dtype, columns=None):
    """
    Scikit-Learn Transformer:
    Transform a count matrix to a normalized tf or tf-idf representation.

    The steps follow scikit-learn's order: optional sublinear tf scaling, optional
    idf weighting, then row-wise L2 normalisation. Only L2 normalisation is emitted;
    ``op.norm`` is not consulted. ``dshape`` and ``columns`` are not used.
    """
    tf = inexpr

    if op.sublinear_tf:
        # Sublinear scaling replaces tf with 1 + log(tf) and must happen before the
        # idf weighting. Zero counts stay zero, as in scikit-learn where only the
        # stored (non-zero) entries are rescaled.
        zeros = _op.full_like(inexpr, _op.const(0, dtype=dtype))
        ones = _op.full_like(inexpr, _op.const(1, dtype=dtype))
        has_count = _op.greater(inexpr, zeros)
        # Substitute 1 for the zero counts so that log never sees 0.
        safe_tf = _op.where(has_count, inexpr, ones)
        tf = _op.where(has_count, _op.add(_op.log(safe_tf), ones), zeros)

    if op.use_idf:
        idf = _op.const(np.array(op.idf_, dtype=dtype), dtype=dtype)
        tf = _op.multiply(idf, tf)

    return _op.nn.l2_normalize(tf, eps=0.0001, axis=[1])
290+
291+
292+
def _PCA(op, inexpr, dshape, dtype, columns=None):
    """
    Scikit-Learn Transformer:
    PCA transformation with existing eigen vector.

    Mirrors scikit-learn's ``PCA.transform``: the input is centered with ``op.mean_``
    (when present), projected onto ``op.components_`` and, if ``op.whiten`` is set,
    divided by ``sqrt(op.explained_variance_)``. ``dshape`` and ``columns`` are not
    used.
    """
    # scikit-learn centers the data before projecting it; skipping this shifts every
    # output row by a constant offset.
    mean = getattr(op, "mean_", None)
    if mean is not None:
        inexpr = _op.subtract(inexpr, _op.const(np.array(mean, dtype), dtype))

    # nn.dense multiplies by the transposed weight, i.e. computes X . components_^T.
    eigvec = _op.const(np.array(op.components_, dtype))
    ret = _op.nn.dense(inexpr, eigvec)

    if getattr(op, "whiten", False):
        scale = np.sqrt(np.array(op.explained_variance_, dtype))
        ret = _op.divide(ret, _op.const(scale, dtype))

    return ret
300+
301+
130302
# Dispatch table: estimator class name -> {method name -> converter function}.
# An estimator method that has no entry here is reported as unsupported by
# sklearn_op_to_relay.
_convert_map = {
    # Composition
    "ColumnTransformer": {"transform": _ColumnTransformer},
    # Imputation, scaling and encoding of features
    "SimpleImputer": {"transform": _SimpleImputer},
    "RobustImputer": {"transform": _RobustImputer},
    "RobustStandardScaler": {"transform": _RobustStandardScaler},
    "ThresholdOneHotEncoder": {"transform": _ThresholdOneHotEncoder},
    # Label encoders
    "NALabelEncoder": {
        "transform": _NALabelEncoder,
        "inverse_transform": _InverseLabelTransformer,
    },
    "RobustLabelEncoder": {"inverse_transform": _InverseLabelTransformer},
    "RobustOrdinalEncoder": {"transform": _RobustOrdinalEncoder},
    # Discretisation, text and decomposition
    "KBinsDiscretizer": {"transform": _KBinsDiscretizer},
    "TfidfVectorizer": {"transform": _TfidfVectorizer},
    "PCA": {"transform": _PCA},
}
137315

138-
def sklearn_op_to_relay(op, inexpr, dshape, dtype, func_name, columns=None):
    """
    Convert Sklearn Ops to Relay Ops.

    The converter is looked up in ``_convert_map`` by the class name of ``op`` and by
    ``func_name``, then called with ``(op, inexpr, dshape, dtype, columns)``.

    Raises
    ------
    NameError
        If the class of ``op`` has no entry in ``_convert_map``, or if that entry has
        no converter for ``func_name``.
    """
    classname = type(op).__name__

    if classname not in _convert_map:
        raise NameError("Model {} not supported in scikit-learn frontend".format(classname))

    converters = _convert_map[classname]
    if func_name not in converters:
        raise NameError(
            "Function {} of Model {} not supported in scikit-learn frontend".format(
                func_name, classname
            )
        )

    convert = converters[func_name]
    # ColumnTransformer calls back into this dispatcher for its children, so it is
    # the only converter that also receives func_name.
    extra_args = (func_name,) if classname == "ColumnTransformer" else ()
    return convert(op, inexpr, dshape, dtype, *extra_args, columns)
336+
337+
338+
def from_sklearn(model, shape=None, dtype="float32", func_name="transform", columns=None):
    """
    Import scikit-learn model to Relay.

    A Relay variable named ``"input"`` with the given ``shape`` and ``dtype`` is run
    through the converter selected for ``model`` and ``func_name``; the resulting
    expression is wrapped in a function over its free variables.

    Returns
    -------
    A tuple of the ``IRModule`` built from that function and an empty list.

    Raises
    ------
    ImportError
        If scikit-learn cannot be imported.
    """
    try:
        # Imported only to verify that scikit-learn is available.
        import sklearn
    except ImportError as e:
        message = "Unable to import scikit-learn which is required {}".format(e)
        raise ImportError(message)

    data = _expr.var("input", shape=shape, dtype=dtype)
    body = sklearn_op_to_relay(model, data, shape, dtype, func_name, columns)

    relay_func = _function.Function(analysis.free_vars(body), body)
    mod = IRModule.from_expr(relay_func)
    return mod, []
158352

159-
def from_auto_ml(model, shape=None, dtype="float32", func_name="transform"):
    """
    Import an AutoML scikit-learn model to Relay.

    With ``func_name == "transform"`` every step of ``model.feature_transformer`` is
    converted in order, each one consuming the output of the previous step. For any
    other ``func_name`` only ``model.target_transformer`` is converted. The same
    ``shape`` is handed to every converter.

    Returns
    -------
    A tuple of the ``IRModule`` built from the converted expression and an empty list.

    Raises
    ------
    ImportError
        If scikit-learn cannot be imported.
    """
    try:
        # Imported only to verify that scikit-learn is available.
        import sklearn
    except ImportError as e:
        message = "Unable to import scikit-learn which is required {}".format(e)
        raise ImportError(message)

    outexpr = _expr.var("input", shape=shape, dtype=dtype)

    if func_name == "transform":
        stages = [stage for _, stage in model.feature_transformer.steps]
    else:
        stages = [model.target_transformer]

    for stage in stages:
        outexpr = sklearn_op_to_relay(stage, outexpr, shape, dtype, func_name, None)

    relay_func = _function.Function(analysis.free_vars(outexpr), outexpr)
    mod = IRModule.from_expr(relay_func)
    return mod, []

python/tvm/relay/op/_tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -275,4 +275,4 @@ def elemwise_shape_func(attrs, inputs, _):
275275
register_shape_func("isnan", False, elemwise_shape_func)
276276
register_shape_func("isinf", False, elemwise_shape_func)
277277
register_shape_func("where", False, elemwise_shape_func)
278-
278+
register_shape_func("copy", False, elemwise_shape_func)

0 commit comments

Comments
 (0)