1616from jax import random
1717from numpyro .infer import MCMC , NUTS , init_to_sample
1818from sklearn .ensemble import RandomForestRegressor
19+ from sklearn .impute import KNNImputer
1920from sklearn .preprocessing import OneHotEncoder , OrdinalEncoder
2021from typing_extensions import Self
2122
@@ -427,11 +428,8 @@ def __init__(
427428 self .data = (
428429 data_t .select (["season" , "geography" , "t" , "estimate" ])
429430 .pivot (on = "t" , values = "estimate" , sort_columns = True )
430- # impute zero uptake at start of season
431- .with_columns (pl .coalesce (pl .col ("0" ), 0.0 ))
432- # drop season/geo's with any other missing values
433- .drop_nulls ()
434431 .sort (["season" , "geography" ])
432+ .pipe (self ._impute )
435433 )
436434
437435 self .forecast_season = pl .select (
@@ -445,6 +443,24 @@ def __init__(
445443 ).item ()
446444 self .forecast_month = self ._month_in_season (self .forecast_date )
447445
446+ @staticmethod
447+ def _impute (
448+ df : pl .DataFrame , index_cols : tuple [str , ...] = ("season" , "geography" )
449+ ):
450+ to_impute_df = df .drop (index_cols )
451+ imputed_np = KNNImputer (n_neighbors = 2 ).fit_transform (to_impute_df .to_numpy ())
452+ imputed_df = pl .concat (
453+ [
454+ df .select (index_cols ),
455+ pl .DataFrame (imputed_np , schema = to_impute_df .columns ),
456+ ],
457+ how = "horizontal" ,
458+ )
459+ assert imputed_df .null_count ().sum_horizontal ().item () == 0 , (
460+ "Null remaining in data"
461+ )
462+ return imputed_df
463+
448464 def _month_in_season (self , date : datetime .date ) -> int :
449465 assert date .day == 1
450466 year = date .year
0 commit comments