11import json
22import os
3- import os . path as osp
3+ import pickle
44
55from sklearn .preprocessing import LabelEncoder , OneHotEncoder , OrdinalEncoder
66
1212
1313
1414class LabelEncode (Preprocessor ):
15- def __init__ (self , encoder_type = "le" , save_folder = None , ** kwargs ):
15+ def __init__ (
16+ self , encoder_type = "le" , pickle_path = None , engine : str = "pandas" , ** kwargs
17+ ):
1618 super ().__init__ (** kwargs )
1719
1820 assert encoder_type in [
@@ -22,48 +24,110 @@ def __init__(self, encoder_type="le", save_folder=None, **kwargs):
2224 ], "Encoder type not supported"
2325
2426 self .encoder_type = encoder_type
25- self .save_folder = save_folder
26- self .mapping_dict = {}
27-
28- if self .encoder_type == "le" :
29- self .encoder = LabelEncoder ()
30- elif self .encoder_type == "onehot" :
31- self .encoder = OneHotEncoder ()
27+ self .pickle_path = pickle_path
28+ self .engine = engine
29+ if self .engine == "polars" :
30+ import polars as pl
31+
32+ if self .pickle_path is not None :
33+ with open (self .pickle_path , "rb" ) as fb :
34+ config = pickle .load (fb )
35+ self .column_names = config ["column_names" ]
36+ self .encoder_type = config ["encoder_type" ]
37+ self .engine = config ["engine" ]
38+ self .encoders = config ["encoders" ]
39+ self .log (f"Loaded mapping dict from { self .pickle_path } " )
3240 else :
33- self .encoder = OrdinalEncoder ()
41+ self .encoders = {}
42+ if self .encoder_type == "le" :
43+ encoder = LabelEncoder ()
44+ elif self .encoder_type == "onehot" :
45+ encoder = OneHotEncoder ()
46+ else :
47+ encoder = OrdinalEncoder ()
48+
49+ for column in self .column_names :
50+ self .encoders [column ] = encoder
3451
3552 @classmethod
36- def from_json (cls , json_path : str ):
37- return cls (json_path = json_path , encoder_type = "json_mapping" )
38-
39- def create_mapping_dict (self , column_name ):
40- le_name_mapping = dict (
41- zip (
42- self .encoder .classes_ ,
43- [int (i ) for i in self .encoder .transform (self .encoder .classes_ )],
53+ def from_pickle (cls , pickle_path : str ):
54+ return cls (pickle_path = pickle_path )
55+
56+ def save_pickle (self , pickle_path : str ):
57+ with open (pickle_path , "wb" ) as fb :
58+ pickle .dump (
59+ {
60+ "column_names" : self .column_names ,
61+ "encoder_type" : self .encoder_type ,
62+ "engine" : self .engine ,
63+ "encoders" : self .encoders ,
64+ },
65+ fb ,
4466 )
45- )
46- if self .save_folder is not None :
47- os .makedirs (self .save_folder , exist_ok = True )
48- json .dump (
49- le_name_mapping ,
50- open (osp .join (self .save_folder , column_name + ".json" ), "w" ),
51- indent = 4 ,
67+ self .log (f"Saved encoder to { pickle_path } " )
68+
69+ def save_json (self , json_path : str ):
70+ os .makedirs (os .path .dirname (json_path ), exist_ok = True )
71+ mapping_dict = {}
72+ for column_name in self .column_names :
73+ class_mapping = dict (
74+ zip (
75+ self .encoders [column_name ].classes_ ,
76+ [
77+ int (i )
78+ for i in self .encoders [column_name ].transform (
79+ self .encoders [column_name ].classes_
80+ )
81+ ],
82+ )
5283 )
53- return le_name_mapping
84+ mapping_dict [column_name ] = class_mapping
85+ with open (json_path , "w" ) as fb :
86+ json .dump (mapping_dict , fb , indent = 4 )
87+ self .log (f"Saved mapping dict to { json_path } " )
5488
5589 def encode_corpus (self , df ):
5690 for column_name in self .column_names :
57- df [column_name ] = self .encoder .fit_transform (df [column_name ].values ).copy ()
58- mapping_dict = self .create_mapping_dict (column_name )
59- self .mapping_dict [column_name ] = mapping_dict
91+ encoder = self .encoders [column_name ]
92+ if self .engine == "pandas" :
93+ df [column_name ] = encoder .fit_transform (df [column_name ].values ).copy ()
94+ elif self .engine == "polars" :
95+ import polars as pl
96+
97+ encoder .fit_transform (df [column_name ].to_numpy ())
98+ le_name_mapping = dict (
99+ zip (
100+ encoder .classes_ ,
101+ [int (i ) for i in encoder .transform (encoder .classes_ )],
102+ )
103+ )
104+ df = df .with_columns (
105+ pl .col (column_name ).replace_strict (
106+ le_name_mapping , return_dtype = pl .Int32 , default = None
107+ )
108+ )
109+
60110 return df
61111
62112 def encode_query (self , df ):
63113 for column_name in self .column_names :
64- df [column_name ] = self .apply (
65- df [column_name ], lambda x : self .mapping_dict [column_name ].get (x , - 1 )
66- ).copy ()
114+ encoder = self .encoders [column_name ]
115+ if self .engine == "pandas" :
116+ df [column_name ] = encoder .transform (df [column_name ].values ).copy ()
117+ elif self .engine == "polars" :
118+ import polars as pl
119+
120+ le_name_mapping = dict (
121+ zip (
122+ encoder .classes_ ,
123+ [int (i ) for i in encoder .transform (encoder .classes_ )],
124+ )
125+ )
126+ df = df .with_columns (
127+ pl .col (column_name ).replace_strict (
128+ le_name_mapping , return_dtype = pl .Int32 , default = None
129+ )
130+ )
67131 return df
68132
69133 def run (self , df ):
@@ -75,7 +139,7 @@ def run(self, df):
75139 level = LoggerObserver .WARN ,
76140 )
77141 self .column_names = [col for col , dt in df .dtypes .items () if dt == object ]
78- self .encode_corpus (df )
142+ df = self .encode_corpus (df )
79143
80144 self .log (f"Label-encoded columns: { self .column_names } " )
81145 return df
0 commit comments