Skip to content

Commit 0cd12f3

Browse files
authored
RF: Impute missing coverages (#318)
- Resolves #317 - Resolves #300
1 parent f5f1746 commit 0cd12f3

1 file changed

Lines changed: 20 additions & 4 deletions

File tree

iup/__init__.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from jax import random
1717
from numpyro.infer import MCMC, NUTS, init_to_sample
1818
from sklearn.ensemble import RandomForestRegressor
19+
from sklearn.impute import KNNImputer
1920
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
2021
from typing_extensions import Self
2122

@@ -427,11 +428,8 @@ def __init__(
427428
self.data = (
428429
data_t.select(["season", "geography", "t", "estimate"])
429430
.pivot(on="t", values="estimate", sort_columns=True)
430-
# impute zero uptake at start of season
431-
.with_columns(pl.coalesce(pl.col("0"), 0.0))
432-
# drop season/geo's with any other missing values
433-
.drop_nulls()
434431
.sort(["season", "geography"])
432+
.pipe(self._impute)
435433
)
436434

437435
self.forecast_season = pl.select(
@@ -445,6 +443,24 @@ def __init__(
445443
).item()
446444
self.forecast_month = self._month_in_season(self.forecast_date)
447445

446+
@staticmethod
447+
def _impute(
448+
df: pl.DataFrame, index_cols: tuple[str, ...] = ("season", "geography")
449+
):
450+
to_impute_df = df.drop(index_cols)
451+
imputed_np = KNNImputer(n_neighbors=2).fit_transform(to_impute_df.to_numpy())
452+
imputed_df = pl.concat(
453+
[
454+
df.select(index_cols),
455+
pl.DataFrame(imputed_np, schema=to_impute_df.columns),
456+
],
457+
how="horizontal",
458+
)
459+
assert imputed_df.null_count().sum_horizontal().item() == 0, (
460+
"Null remaining in data"
461+
)
462+
return imputed_df
463+
448464
def _month_in_season(self, date: datetime.date) -> int:
449465
assert date.day == 1
450466
year = date.year

0 commit comments

Comments
 (0)