Commit b6abb31

Merge pull request #372 from guillaume-vignal/master
Enable compatibility with new version of category-encoders
2 parents: 175b1ab + b9b9f03

File tree: 6 files changed (+88 additions, -54 deletions)

requirements.dev.txt

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 pip==21.3.1
 dash==2.3.1
 catboost==0.26.1
-category-encoders==2.1.0
+category-encoders==2.2.2
 dash-bootstrap-components==1.1.0
 dash-core-components==2.0.0
 dash-daq==0.5.0

setup.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@
 extras['lightgbm'] = ['lightgbm>=2.3.0']
 extras['catboost'] = ['catboost>=0.21']
 extras['scikit-learn'] = ['scikit-learn>=0.23.0']
-extras['category_encoders'] = ['category_encoders==2.2.2']
+extras['category_encoders'] = ['category_encoders>=2.2.2']
 extras['acv'] = ['acv-exp==1.1.2']
 extras['lime'] = ['lime']
 

shapash/utils/category_encoder_backend.py

Lines changed: 25 additions & 9 deletions

@@ -4,6 +4,7 @@
 
 import pandas as pd
 import numpy as np
+import category_encoders as ce
 
 category_encoder_onehot = "<class 'category_encoders.one_hot.OneHotEncoder'>"
 category_encoder_ordinal = "<class 'category_encoders.ordinal.OrdinalEncoder'>"
@@ -54,8 +55,12 @@ def inv_transform_ce(x_in, encoding):
         rst = inv_transform_ordinal(x, encoding.ordinal_encoder.mapping)
 
     elif str(type(encoding)) == category_encoder_binary:
-        x = reverse_basen(x_in, encoding.base_n_encoder)
-        rst = inv_transform_ordinal(x, encoding.base_n_encoder.ordinal_encoder.mapping)
+        if ce.__version__ <= '2.2.2':
+            x = reverse_basen(x_in, encoding.base_n_encoder)
+            rst = inv_transform_ordinal(x, encoding.base_n_encoder.ordinal_encoder.mapping)
+        else:
+            x = reverse_basen(x_in, encoding)
+            rst = inv_transform_ordinal(x, encoding.ordinal_encoder.mapping)
 
     elif str(type(encoding)) == category_encoder_targetencoder:
         rst = inv_transform_target(x_in, encoding)
@@ -106,8 +111,8 @@ def inv_transform_target(x_in, enc_target):
         rst_target = pd.concat([reverse_target, mapping_ordinal], axis=1, join='inner').fillna(value='NaN')
         aggregate = rst_target.groupby(1)[0].apply(lambda x: ' / '.join(map(str, x)))
         if aggregate.shape[0] != rst_target.shape[0]:
-            raise Exception('Multiple label found for the same value in TargetEncoder on col '+str(name_target) +'.')
-            #print("Warning in inverse TargetEncoder - col " + str(name_target) + ": Multiple label for the same value, "
+            raise Exception('Multiple label found for the same value in TargetEncoder on col '+str(name_target) + '.')
+            # print("Warning in inverse TargetEncoder - col " + str(name_target) + ": Multiple label for the same value, "
             #       "each label will be separate using : / ")
 
         transco = {'col': name_target,
@@ -138,7 +143,10 @@ def inv_transform_ordinal(x_in, encoding):
         if not col_name in x_in.columns:
             raise Exception(f'Columns {col_name} not in dataframe.')
         column_mapping = switch.get('mapping')
-        inverse = pd.Series(data=column_mapping.index, index=column_mapping.values)
+        if isinstance(column_mapping, dict):
+            inverse = pd.Series(data=column_mapping.keys(), index=column_mapping.values())
+        else:
+            inverse = pd.Series(data=column_mapping.index, index=column_mapping.values)
         x_in[col_name] = x_in[col_name].map(inverse).astype(switch.get('data_type'))
     return x_in
 
@@ -201,7 +209,7 @@ def calc_inv_contrib_ce(x_contrib, encoding, agg_columns):
         The aggregate contributions depending on which processing is apply.
     """
     if str(type(encoding)) in dummies_category_encoder:
-        if str(type(encoding)) in category_encoder_binary:
+        if str(type(encoding)) in category_encoder_binary and ce.__version__ <= '2.2.2':
             encoding = encoding.base_n_encoder
         drop_col = []
         for switch in encoding.mapping:
@@ -218,6 +226,7 @@ def calc_inv_contrib_ce(x_contrib, encoding, agg_columns):
     else:
         return x_contrib
 
+
 def transform_ce(x_in, encoding):
     """
     Choose and apply the transformation for the given encoding.
@@ -242,14 +251,15 @@ def transform_ce(x_in, encoding):
     if str(type(encoding)) in encoder:
         rst = encoding.transform(x_in)
 
-    elif isinstance(encoding,list):
+    elif isinstance(encoding, list):
         rst = transform_ordinal(x_in, encoding)
 
     else:
         raise Exception(f"{encoding.__class__.__name__} not supported, no preprocessing done.")
 
     return rst
 
+
 def transform_ordinal(x_in, encoding):
     """
     Transformation based on ordinal category encoder.
@@ -271,7 +281,10 @@ def transform_ordinal(x_in, encoding):
         if not col_name in x_in.columns:
             raise Exception(f'Columns {col_name} not in dataframe.')
         column_mapping = switch.get('mapping')
-        transform = pd.Series(data=column_mapping.values, index=column_mapping.index)
+        if isinstance(column_mapping, dict):
+            transform = pd.Series(data=column_mapping.values(), index=column_mapping.keys())
+        else:
+            transform = pd.Series(data=column_mapping.values, index=column_mapping.index)
         x_in[col_name] = x_in[col_name].map(transform).astype(switch.get('mapping').values.dtype)
     return x_in
 
@@ -294,7 +307,10 @@ def get_col_mapping_ce(encoder):
                               category_encoder_targetencoder]:
         encoder_mapping = encoder.mapping
     elif str(type(encoder)) == category_encoder_binary:
-        encoder_mapping = encoder.base_n_encoder.mapping
+        if ce.__version__ <= '2.2.2':
+            encoder_mapping = encoder.base_n_encoder.mapping
+        else:
+            encoder_mapping = encoder.mapping
    else:
        raise NotImplementedError(f"{encoder} not supported.")
 
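
Why the version gate: in category-encoders releases up to 2.2.2 a fitted BinaryEncoder wraps a BaseNEncoder (exposed as base_n_encoder), whereas later releases expose the mapping and ordinal encoder directly on the BinaryEncoder, and the stored ordinal mapping may arrive as a plain dict rather than a pandas Series. The sketch below is not part of the commit; it only illustrates the two layouts the patch handles, and the small example frame is made up for the purpose.

import category_encoders as ce
import pandas as pd

# Toy data, only to have something to fit on.
df = pd.DataFrame({'city': ['chicago', 'paris', 'chicago']})
enc = ce.BinaryEncoder(cols=['city']).fit(df)

# Same string comparison the patch uses to pick the attribute layout.
if ce.__version__ <= '2.2.2':
    inner = enc.base_n_encoder   # old layout: BaseNEncoder nested inside BinaryEncoder
else:
    inner = enc                  # new layout: mapping / ordinal_encoder live on the encoder itself

# The per-column ordinal mapping may be a Series (category -> code) or a dict,
# which is what the isinstance(column_mapping, dict) branches above account for.
column_mapping = inner.ordinal_encoder.mapping[0]['mapping']
if isinstance(column_mapping, dict):
    inverse = pd.Series(data=list(column_mapping.keys()), index=list(column_mapping.values()))
else:
    inverse = pd.Series(data=column_mapping.index, index=column_mapping.values)
print(inverse)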

shapash/utils/columntransformer_backend.py

Lines changed: 10 additions & 3 deletions

@@ -179,6 +179,7 @@ def inv_transform_sklearn_in_ct(x_in, init, name_encoding, col_encoding, ct_enco
         init += nb_col
     return frame, init
 
+
 def calc_inv_contrib_ct(x_contrib, encoding, agg_columns):
     """
     Reversed contribution when ColumnTransformer is used.
@@ -226,7 +227,10 @@ def calc_inv_contrib_ct(x_contrib, encoding, agg_columns):
             if str(type(ct_encoding)) == sklearn_onehot:
                 col_origin = ct_encoding.categories_[i_enc]
             elif str(type(ct_encoding)) == category_encoder_binary:
-                col_origin = ct_encoding.base_n_encoder.mapping[i_enc].get('mapping').columns.tolist()
+                try:
+                    col_origin = ct_encoding.base_n_encoder.mapping[i_enc].get('mapping').columns.tolist()
+                except:
+                    col_origin = ct_encoding.mapping[i_enc].get('mapping').columns.tolist()
             else:
                 col_origin = ct_encoding.mapping[i_enc].get('mapping').columns.tolist()
             nb_col = len(col_origin)
@@ -292,8 +296,8 @@ def transform_ct(x_in, model, encoding):
 
     elif str(type(model)) in other_model:
         rst = pd.DataFrame(encoding.transform(x_in),
-                            columns=extract_features_model(model, dict_model_feature[str(type(model))]),
-                            index=x_in.index)
+                           columns=extract_features_model(model, dict_model_feature[str(type(model))]),
+                           index=x_in.index)
     else:
         raise ValueError("Model specified isn't supported by Shapash.")
 
@@ -305,6 +309,7 @@ def transform_ct(x_in, model, encoding):
 
     return rst
 
+
 def get_names(name, trans, column, column_transformer):
     """
     Allow to extract features names from one encoder of the ColumnTransformer.
@@ -347,6 +352,7 @@ def get_names(name, trans, column, column_transformer):
 
         return [name + "__" + f for f in trans.get_feature_names()]
 
+
 def get_feature_names(column_transformer):
     """
     Allow to extract all features names from encoders of the ColumnTransformer once it has been applied.
@@ -370,6 +376,7 @@ def get_feature_names(column_transformer):
 
     return feature_names
 
+
 def get_list_features_names(list_preprocessing, columns_dict):
     """
     Allow to extract all features names from encoders when a list of preprocessing is uesd once it has been applied.
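
In calc_inv_contrib_ct the same two BinaryEncoder layouts can appear when the encoder sits inside a scikit-learn ColumnTransformer, so the patch falls back from base_n_encoder.mapping to mapping via try/except. As a sketch only (not the commit's code), the same fallback can be expressed with getattr; the toy frame is invented for illustration, and the column names it prints follow the mapping structure the patch itself reads (mapping[i].get('mapping').columns).

import category_encoders as ce
import pandas as pd

df = pd.DataFrame({'state': ['US', 'FR', 'US']})
ct_encoding = ce.BinaryEncoder(cols=['state']).fit(df)

# Use the nested BaseNEncoder when it exists (category-encoders <= 2.2.2),
# otherwise read the mapping straight off the BinaryEncoder.
source = getattr(ct_encoding, 'base_n_encoder', ct_encoding)
col_origin = source.mapping[0].get('mapping').columns.tolist()
print(col_origin)   # e.g. ['state_0', 'state_1']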

tests/unit_tests/utils/test_category_encoders_backend.py

Lines changed: 13 additions & 11 deletions

@@ -6,7 +6,6 @@
 import numpy as np
 import category_encoders as ce
 import catboost as cb
-import sklearn
 import lightgbm
 import xgboost
 from shapash.utils.transform import inverse_transform, apply_preprocessing, get_col_mapping_ce
@@ -87,7 +86,6 @@ def test_inverse_transform_2(self):
 
         pd.testing.assert_frame_equal(expected, original)
 
-
     def test_inverse_transform_3(self):
         """
         Test target encoding
@@ -404,15 +402,15 @@ def test_inverse_transform_26(self):
                                  'BaseN1': ['M', 'N', 'N'], 'BaseN2': ['O', 'P', 'ZZ'],
                                  'Target1': ['Q', 'R', 'R'], 'Target2': ['S', 'T', 'ZZ'],
                                  'other': ['other', '123', np.nan]},
-                                 index=['index1', 'index2', 'index3'])
+                                index=['index1', 'index2', 'index3'])
 
         expected = pd.DataFrame({'Onehot1': ['A', 'B', 'A'], 'Onehot2': ['C', 'D', 'missing'],
                                  'Binary1': ['E', 'F', 'F'], 'Binary2': ['G', 'H', 'missing'],
                                  'Ordinal1': ['I', 'J', 'J'], 'Ordinal2': ['K', 'L', 'missing'],
                                  'BaseN1': ['M', 'N', 'N'], 'BaseN2': ['O', 'P', np.nan],
                                  'Target1': ['Q', 'R', 'R'], 'Target2': ['S', 'T', 'NaN'],
                                  'other': ['other', '123', np.nan]},
-                                 index=['index1', 'index2', 'index3'])
+                                index=['index1', 'index2', 'index3'])
 
         y = pd.DataFrame(data=[0, 1, 0, 0], columns=['y'])
 
@@ -668,7 +666,7 @@ def test_get_col_mapping_ce_1(self):
         y = pd.DataFrame(data=[0, 1, 1], columns=['y'])
 
         enc = ce.TargetEncoder(cols=['city', 'state'])
-        test_encoded = pd.DataFrame(enc.fit_transform(test, y))
+        enc.fit(test, y)
 
         mapping = get_col_mapping_ce(enc)
         expected_mapping = {'city': ['city'], 'state': ['state']}
@@ -685,7 +683,7 @@ def test_get_col_mapping_ce_2(self):
         y = pd.DataFrame(data=[0, 1, 1], columns=['y'])
 
         enc = ce.OrdinalEncoder(handle_missing='value', handle_unknown='value')
-        test_encoded = pd.DataFrame(enc.fit_transform(test, y))
+        enc.fit(test, y)
 
         mapping = get_col_mapping_ce(enc)
         expected_mapping = {'city': ['city'], 'state': ['state'], 'other': ['other']}
@@ -702,7 +700,7 @@ def test_get_col_mapping_ce_3(self):
         y = pd.DataFrame(data=[0, 1, 1], columns=['y'])
 
         enc = ce.BinaryEncoder(cols=['city', 'state'])
-        test_encoded = pd.DataFrame(enc.fit_transform(test, y))
+        enc.fit(test, y)
 
         mapping = get_col_mapping_ce(enc)
         expected_mapping = {'city': ['city_0', 'city_1'], 'state': ['state_0', 'state_1']}
@@ -719,11 +717,15 @@ def test_get_col_mapping_ce_4(self):
         y = pd.DataFrame(data=[0, 1, 1], columns=['y'])
 
         enc = ce.BaseNEncoder(base=2)
-        test_encoded = pd.DataFrame(enc.fit_transform(test, y))
+        enc.fit(test, y)
 
         mapping = get_col_mapping_ce(enc)
-        expected_mapping = {'city': ['city_0', 'city_1', 'city_2'], 'state': ['state_0', 'state_1'],
-                            'other': ['other_0', 'other_1']}
+        if ce.__version__ <= '2.2.2':
+            expected_mapping = {'city': ['city_0', 'city_1', 'city_2'], 'state': ['state_0', 'state_1'],
+                                'other': ['other_0', 'other_1']}
+        else:
+            expected_mapping = {'city': ['city_0', 'city_1'], 'state': ['state_0', 'state_1'],
+                                'other': ['other_0', 'other_1']}
 
         self.assertDictEqual(mapping, expected_mapping)
 
@@ -737,7 +739,7 @@ def test_get_col_mapping_ce_5(self):
         y = pd.DataFrame(data=[0, 1, 1], columns=['y'])
 
         enc = ce.OneHotEncoder(cols=['city', 'state'], use_cat_names=True)
-        test_encoded = pd.DataFrame(enc.fit_transform(test, y))
+        enc.fit(test, y)
 
         mapping = get_col_mapping_ce(enc)
         expected_mapping = {'city': ['city_chicago', 'city_paris'], 'state': ['state_US', 'state_FR']}
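
The version-dependent expectation in test_get_col_mapping_ce_4 reflects that, per the test above, category-encoders releases up to 2.2.2 emit an extra base-N column for some features (the 'city_2' column), while later releases do not. A quick, hypothetical check is to fit a BaseNEncoder the same way and print the mapping get_col_mapping_ce reports; the real `test` fixture is defined earlier in the test module and is not shown in this diff, so placeholder data stands in for it here.

import pandas as pd
import category_encoders as ce
from shapash.utils.transform import get_col_mapping_ce

# Placeholder frame standing in for the test module's `test` fixture.
test = pd.DataFrame({'city': ['chicago', 'paris', 'new york'],
                     'state': ['US', 'FR', 'US']})
y = pd.DataFrame(data=[0, 1, 1], columns=['y'])

enc = ce.BaseNEncoder(base=2)
enc.fit(test, y)

print(ce.__version__)
print(get_col_mapping_ce(enc))   # compare against the two expected mappings in the test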
