1616from jax import random
1717from numpyro .infer import MCMC , NUTS , Predictive , init_to_sample
1818from sklearn .ensemble import RandomForestRegressor
19- from sklearn .preprocessing import OneHotEncoder
19+ from sklearn .preprocessing import OneHotEncoder , OrdinalEncoder
2020from typing_extensions import Self
2121
2222
@@ -168,32 +168,8 @@ def __init__(
168168 )
169169
170170 # preprocess data
171- self .data = self ._preprocess (
172- data = self .raw_data ,
173- date_column = self .date_column ,
174- season_start_month = self .season ["start_month" ],
175- season_start_day = self .season ["start_day" ],
176- )
177-
178- # do the indexing
179- self .n_group_levels = [
180- self .data .select (pl .col (group ).unique ()).height
181- for group in ["season" , "geography" , "season_geo" ]
182- ]
183-
184- # initialize MCMC. `None` is a placeholder indicating fitting has not occurred
185- self .mcmc = None
186-
187- @classmethod
188- def _preprocess (
189- cls ,
190- data : pl .DataFrame ,
191- date_column : str ,
192- season_start_month : int ,
193- season_start_day : int ,
194- ) -> pl .DataFrame :
195- out = (
196- data
171+ self .data = (
172+ self .raw_data
197173 # prepare observation data
198174 .rename ({"sample_size" : "N_tot" })
199175 .with_columns (N_vax = (pl .col ("N_tot" ) * pl .col ("estimate" )).round (0 ))
@@ -202,44 +178,24 @@ def _preprocess(
202178 season_geo = pl .concat_str (["season" , "geography" ], separator = "_" )
203179 )
204180 .with_columns (
205- t = cls ._days_in_season (
181+ t = self ._days_in_season (
206182 pl .col (date_column ),
207- season_start_month = season_start_month ,
208- season_start_day = season_start_day ,
183+ season_start_month = self . season [ "start_month" ] ,
184+ season_start_day = self . season [ "start_day" ] ,
209185 )
210186 / 365
211187 )
212188 )
213189
214- # add the indices
215- out = cls ._index (out , groups = ["season" , "geography" , "season_geo" ])
216-
217- return out
218-
219- @staticmethod
220- def _index (data : pl .DataFrame , groups : list [str ]) -> pl .DataFrame :
221- """
222- For each column in `groups` (e.g., `"season"`), add a new column `"{group}_idx"`
223- (e.g., `"season_idx"`) that has the values in the original column replaced by
224- integer indices.
225-
226- Args:
227- data: dataframe
228- groups: names of columns
229-
230- Returns: dataframe with additional columns like `"{group}_idx"`
231- """
232- for group in groups :
233- unique_values = (
234- data .select (pl .col (group ).unique ().sort ()).get_column (group ).to_list ()
235- )
236- indices = list (range (len (unique_values )))
237- replace_map = {value : index for value , index in zip (unique_values , indices )}
238- data = data .with_columns (
239- pl .col (group ).replace_strict (replace_map ).alias (f"{ group } _idx" )
240- )
190+ # set up encoder
191+ self .groups = ("season" , "geography" , "season_geo" )
192+ self .enc = OrdinalEncoder (dtype = np .int64 ).fit (
193+ self .data .select (self .groups ).to_numpy ()
194+ )
195+ self .n_group_levels = [len (x ) for x in self .enc .categories_ ]
241196
242- return data
197+ # initialize MCMC. `None` is a placeholder indicating fitting has not occurred
198+ self .mcmc = None
243199
244200 @staticmethod
245201 def _days_in_season (
@@ -280,12 +236,8 @@ def model(self, data: pl.DataFrame):
280236 t = jnp .array (data ["t" ]),
281237 # jax runs into a problem if you don't specify this type
282238 N_tot = jnp .array (data ["N_tot" ], dtype = jnp .int32 ),
283- groups = jnp .array (
284- data .select (
285- [f"{ group } _idx" for group in ["season" , "geography" , "season_geo" ]]
286- )
287- ),
288- n_groups = 3 ,
239+ groups = jnp .array (self .enc .transform (data .select (self .groups ).to_numpy ())),
240+ n_groups = len (self .groups ),
289241 n_group_levels = self .n_group_levels ,
290242 ** self .model_params ,
291243 )
@@ -505,7 +457,7 @@ def _month_in_season(self, date: datetime.date) -> int:
505457 return (year - ssy ) * 12 + (date .month - self .season ["start_month" ])
506458
507459 def fit (self ) -> Self :
508- self .enc = Encoder ().fit (self .data )
460+ self .enc = RFEncoder ().fit (self .data )
509461
510462 self .X_features = ["season" , "geography" ] + [
511463 str (t )
@@ -575,7 +527,7 @@ def _postprocess(
575527 )
576528
577529
578- class Encoder :
530+ class RFEncoder :
579531 def __init__ (self , categorical_features : tuple = ("season" , "geography" )):
580532 self .categorical_features = categorical_features
581533 self .enc = OneHotEncoder (sparse_output = False )
0 commit comments