Skip to content

Commit b598083

Browse files
committed
LPL: Use an encoder
1 parent a566454 commit b598083

1 file changed

Lines changed: 18 additions & 66 deletions

File tree

iup/__init__.py

Lines changed: 18 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
from jax import random
1717
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample
1818
from sklearn.ensemble import RandomForestRegressor
19-
from sklearn.preprocessing import OneHotEncoder
19+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
2020
from typing_extensions import Self
2121

2222

@@ -168,32 +168,8 @@ def __init__(
168168
)
169169

170170
# preprocess data
171-
self.data = self._preprocess(
172-
data=self.raw_data,
173-
date_column=self.date_column,
174-
season_start_month=self.season["start_month"],
175-
season_start_day=self.season["start_day"],
176-
)
177-
178-
# do the indexing
179-
self.n_group_levels = [
180-
self.data.select(pl.col(group).unique()).height
181-
for group in ["season", "geography", "season_geo"]
182-
]
183-
184-
# initialize MCMC. `None` is a placeholder indicating fitting has not occurred
185-
self.mcmc = None
186-
187-
@classmethod
188-
def _preprocess(
189-
cls,
190-
data: pl.DataFrame,
191-
date_column: str,
192-
season_start_month: int,
193-
season_start_day: int,
194-
) -> pl.DataFrame:
195-
out = (
196-
data
171+
self.data = (
172+
self.raw_data
197173
# prepare observation data
198174
.rename({"sample_size": "N_tot"})
199175
.with_columns(N_vax=(pl.col("N_tot") * pl.col("estimate")).round(0))
@@ -202,44 +178,24 @@ def _preprocess(
202178
season_geo=pl.concat_str(["season", "geography"], separator="_")
203179
)
204180
.with_columns(
205-
t=cls._days_in_season(
181+
t=self._days_in_season(
206182
pl.col(date_column),
207-
season_start_month=season_start_month,
208-
season_start_day=season_start_day,
183+
season_start_month=self.season["start_month"],
184+
season_start_day=self.season["start_day"],
209185
)
210186
/ 365
211187
)
212188
)
213189

214-
# add the indices
215-
out = cls._index(out, groups=["season", "geography", "season_geo"])
216-
217-
return out
218-
219-
@staticmethod
220-
def _index(data: pl.DataFrame, groups: list[str]) -> pl.DataFrame:
221-
"""
222-
For each column in `groups` (e.g., `"season"`), add a new column `"{group}_idx"`
223-
(e.g., `"season_idx"`) that has the values in the original column replaced by
224-
integer indices.
225-
226-
Args:
227-
data: dataframe
228-
groups: names of columns
229-
230-
Returns: dataframe with additional columns like `"{group}_idx"`
231-
"""
232-
for group in groups:
233-
unique_values = (
234-
data.select(pl.col(group).unique().sort()).get_column(group).to_list()
235-
)
236-
indices = list(range(len(unique_values)))
237-
replace_map = {value: index for value, index in zip(unique_values, indices)}
238-
data = data.with_columns(
239-
pl.col(group).replace_strict(replace_map).alias(f"{group}_idx")
240-
)
190+
# set up encoder
191+
self.groups = ("season", "geography", "season_geo")
192+
self.enc = OrdinalEncoder(dtype=np.int64).fit(
193+
self.data.select(self.groups).to_numpy()
194+
)
195+
self.n_group_levels = [len(x) for x in self.enc.categories_]
241196

242-
return data
197+
# initialize MCMC. `None` is a placeholder indicating fitting has not occurred
198+
self.mcmc = None
243199

244200
@staticmethod
245201
def _days_in_season(
@@ -280,12 +236,8 @@ def model(self, data: pl.DataFrame):
280236
t=jnp.array(data["t"]),
281237
# jax runs into a problem if you don't specify this type
282238
N_tot=jnp.array(data["N_tot"], dtype=jnp.int32),
283-
groups=jnp.array(
284-
data.select(
285-
[f"{group}_idx" for group in ["season", "geography", "season_geo"]]
286-
)
287-
),
288-
n_groups=3,
239+
groups=jnp.array(self.enc.transform(data.select(self.groups).to_numpy())),
240+
n_groups=len(self.groups),
289241
n_group_levels=self.n_group_levels,
290242
**self.model_params,
291243
)
@@ -505,7 +457,7 @@ def _month_in_season(self, date: datetime.date) -> int:
505457
return (year - ssy) * 12 + (date.month - self.season["start_month"])
506458

507459
def fit(self) -> Self:
508-
self.enc = Encoder().fit(self.data)
460+
self.enc = RFEncoder().fit(self.data)
509461

510462
self.X_features = ["season", "geography"] + [
511463
str(t)
@@ -575,7 +527,7 @@ def _postprocess(
575527
)
576528

577529

578-
class Encoder:
530+
class RFEncoder:
579531
def __init__(self, categorical_features: tuple = ("season", "geography")):
580532
self.categorical_features = categorical_features
581533
self.enc = OneHotEncoder(sparse_output=False)

0 commit comments

Comments
 (0)