Skip to content

Commit f4d5400

Browse files
committed
fix label encoder
1 parent 54f0967 commit f4d5400

File tree

2 files changed

+98
-34
lines changed

2 files changed

+98
-34
lines changed
Lines changed: 97 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
import os.path as osp
3+
import pickle
44

55
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, OrdinalEncoder
66

@@ -12,7 +12,9 @@
1212

1313

1414
class LabelEncode(Preprocessor):
15-
def __init__(self, encoder_type="le", save_folder=None, **kwargs):
15+
def __init__(
16+
self, encoder_type="le", pickle_path=None, engine: str = "pandas", **kwargs
17+
):
1618
super().__init__(**kwargs)
1719

1820
assert encoder_type in [
@@ -22,48 +24,110 @@ def __init__(self, encoder_type="le", save_folder=None, **kwargs):
2224
], "Encoder type not supported"
2325

2426
self.encoder_type = encoder_type
25-
self.save_folder = save_folder
26-
self.mapping_dict = {}
27-
28-
if self.encoder_type == "le":
29-
self.encoder = LabelEncoder()
30-
elif self.encoder_type == "onehot":
31-
self.encoder = OneHotEncoder()
27+
self.pickle_path = pickle_path
28+
self.engine = engine
29+
if self.engine == "polars":
30+
import polars as pl
31+
32+
if self.pickle_path is not None:
33+
with open(self.pickle_path, "rb") as fb:
34+
config = pickle.load(fb)
35+
self.column_names = config["column_names"]
36+
self.encoder_type = config["encoder_type"]
37+
self.engine = config["engine"]
38+
self.encoders = config["encoders"]
39+
self.log(f"Loaded mapping dict from {self.pickle_path}")
3240
else:
33-
self.encoder = OrdinalEncoder()
41+
self.encoders = {}
42+
if self.encoder_type == "le":
43+
encoder = LabelEncoder()
44+
elif self.encoder_type == "onehot":
45+
encoder = OneHotEncoder()
46+
else:
47+
encoder = OrdinalEncoder()
48+
49+
for column in self.column_names:
50+
self.encoders[column] = encoder
3451

3552
@classmethod
36-
def from_json(cls, json_path: str):
37-
return cls(json_path=json_path, encoder_type="json_mapping")
38-
39-
def create_mapping_dict(self, column_name):
40-
le_name_mapping = dict(
41-
zip(
42-
self.encoder.classes_,
43-
[int(i) for i in self.encoder.transform(self.encoder.classes_)],
53+
def from_pickle(cls, pickle_path: str):
54+
return cls(pickle_path=pickle_path)
55+
56+
def save_pickle(self, pickle_path: str):
57+
with open(pickle_path, "wb") as fb:
58+
pickle.dump(
59+
{
60+
"column_names": self.column_names,
61+
"encoder_type": self.encoder_type,
62+
"engine": self.engine,
63+
"encoders": self.encoders,
64+
},
65+
fb,
4466
)
45-
)
46-
if self.save_folder is not None:
47-
os.makedirs(self.save_folder, exist_ok=True)
48-
json.dump(
49-
le_name_mapping,
50-
open(osp.join(self.save_folder, column_name + ".json"), "w"),
51-
indent=4,
67+
self.log(f"Saved encoder to {pickle_path}")
68+
69+
def save_json(self, json_path: str):
70+
os.makedirs(os.path.dirname(json_path), exist_ok=True)
71+
mapping_dict = {}
72+
for column_name in self.column_names:
73+
class_mapping = dict(
74+
zip(
75+
self.encoders[column_name].classes_,
76+
[
77+
int(i)
78+
for i in self.encoders[column_name].transform(
79+
self.encoders[column_name].classes_
80+
)
81+
],
82+
)
5283
)
53-
return le_name_mapping
84+
mapping_dict[column_name] = class_mapping
85+
with open(json_path, "w") as fb:
86+
json.dump(mapping_dict, fb, indent=4)
87+
self.log(f"Saved mapping dict to {json_path}")
5488

5589
def encode_corpus(self, df):
5690
for column_name in self.column_names:
57-
df[column_name] = self.encoder.fit_transform(df[column_name].values).copy()
58-
mapping_dict = self.create_mapping_dict(column_name)
59-
self.mapping_dict[column_name] = mapping_dict
91+
encoder = self.encoders[column_name]
92+
if self.engine == "pandas":
93+
df[column_name] = encoder.fit_transform(df[column_name].values).copy()
94+
elif self.engine == "polars":
95+
import polars as pl
96+
97+
encoder.fit_transform(df[column_name].to_numpy())
98+
le_name_mapping = dict(
99+
zip(
100+
encoder.classes_,
101+
[int(i) for i in encoder.transform(encoder.classes_)],
102+
)
103+
)
104+
df = df.with_columns(
105+
pl.col(column_name).replace_strict(
106+
le_name_mapping, return_dtype=pl.Int32, default=None
107+
)
108+
)
109+
60110
return df
61111

62112
def encode_query(self, df):
63113
for column_name in self.column_names:
64-
df[column_name] = self.apply(
65-
df[column_name], lambda x: self.mapping_dict[column_name].get(x, -1)
66-
).copy()
114+
encoder = self.encoders[column_name]
115+
if self.engine == "pandas":
116+
df[column_name] = encoder.transform(df[column_name].values).copy()
117+
elif self.engine == "polars":
118+
import polars as pl
119+
120+
le_name_mapping = dict(
121+
zip(
122+
encoder.classes_,
123+
[int(i) for i in encoder.transform(encoder.classes_)],
124+
)
125+
)
126+
df = df.with_columns(
127+
pl.col(column_name).replace_strict(
128+
le_name_mapping, return_dtype=pl.Int32, default=None
129+
)
130+
)
67131
return df
68132

69133
def run(self, df):
@@ -75,7 +139,7 @@ def run(self, df):
75139
level=LoggerObserver.WARN,
76140
)
77141
self.column_names = [col for col, dt in df.dtypes.items() if dt == object]
78-
self.encode_corpus(df)
142+
df = self.encode_corpus(df)
79143

80144
self.log(f"Label-encoded columns: {self.column_names}")
81145
return df

theseus/nlp/retrieval/models/tf_idf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def identity_tokenizer(text):
1414
class TFIDFEncoder(BaseRetrieval):
1515
def __init__(
1616
self,
17-
min_df: int = 0,
17+
min_df: int = 0.0,
1818
max_df: int = 1.0,
1919
model_path: str = None,
2020
ngram_range: Tuple[int] = (1, 1),

0 commit comments

Comments
 (0)