Skip to content

Commit 0cc87d2

Browse files
authored
RF pipeline (#308)
- Predict multiple dates using the same forest, as per #307.
- Make a three-way distinction for seasons: a date can be in the last/this season, in the this/next season, or in neither. This way we can trim out the summer data that causes problems.
- Use the approach of #282 to generate prediction quantiles.
- Jettison the concept of the "fit," which is incompatible with the RF model.
1 parent 9833e1b commit 0cc87d2

8 files changed

Lines changed: 247 additions & 322 deletions

File tree

iup/__init__.py

Lines changed: 31 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -160,30 +160,44 @@ def validate(self):
160160
)
161161

162162

163-
def date_to_season(
164-
date: pl.Expr, season_start_month: int, season_start_day: int = 1
163+
def to_season(
164+
date: pl.Expr,
165+
season_start_month: int,
166+
season_end_month: int,
167+
season_start_day: int = 1,
168+
season_end_day: int = 1,
165169
) -> pl.Expr:
166-
"""Extract the overwinter disease season from a date.
170+
"""
171+
Identify the overwinter season from a date.
167172
168-
Dates in year Y before the season start (e.g., Sep 1) are in the second part of
169-
the season (i.e., in season Y-1/Y). Dates in year Y after the season start are in
170-
season Y/Y+1. E.g., 2023-10-07 and 2024-04-18 are both in "2023/2024".
173+
Every year, there is a season end (e.g., May 1) and a season start (e.g., Sep 1).
174+
Dates before the season end are associated with the prior season (e.g., Feb 1, 2020
175+
belongs to 2019/2020 season). Dates after the season start are associated with the
176+
next season (e.g., Oct 1, 2020 belongs to 2020/2021). Dates between the season end
177+
and season start are not in any season (e.g., June 1).
171178
172179
Args:
173-
date: Dates in an coverage data frame.
174-
season_start_month: First month of the overwinter disease season.
175-
season_start_day: First day of the first month of the overwinter disease season.
180+
date: dates
181+
season_start_month: first month
182+
season_end_month: last month
183+
season_start_day: first day
184+
season_end_day: last day
176185
177186
Returns:
178-
Seasons for each date.
187+
season like "2020/2021"
179188
"""
189+
assert (season_start_month, season_start_day) > (
190+
season_end_month,
191+
season_end_day,
192+
), "Only overwinter seasons are supported"
180193

181-
# for every date, figure out the season breakpoint in that year
182-
season_start = pl.date(date.dt.year(), season_start_month, season_start_day)
194+
# year of this date
195+
y = date.dt.year()
196+
# start and end dates of seasons in this year
197+
end = pl.date(y, season_end_month, season_end_day)
198+
start = pl.date(y, season_start_month, season_start_day)
183199

184-
# what is the first year in the two-year season indicator?
185-
date_year = date.dt.year()
186-
year1 = pl.when(date < season_start).then(date_year - 1).otherwise(date_year)
200+
# first year of the two-year season
201+
sy1 = pl.when(date <= end).then(y - 1).when(date >= start).then(y).otherwise(None)
187202

188-
year2 = year1 + 1
189-
return pl.format("{}/{}", year1, year2)
203+
return pl.when(sy1.is_null()).then(None).otherwise(pl.format("{}/{}", sy1, sy1 + 1))

iup/models.py

Lines changed: 100 additions & 165 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,9 @@
44
os.environ["JAX_PLATFORMS"] = "cpu"
55

66
import abc
7-
import calendar
87
import datetime
98
import inspect
10-
from typing import Any, List
9+
from typing import Any
1110

1211
import jax.numpy as jnp
1312
import numpy as np
@@ -381,202 +380,138 @@ def __init__(
381380
self.quantiles = quantiles
382381
self.season = season
383382
self.params = params
384-
self.months = self._month_order(self.season["start_month"])
385-
self.end_month_index = self.months.index(
386-
datetime.date(
387-
self.season["end_year"],
388-
self.season["end_month"],
389-
self.season["end_day"],
390-
).strftime("%b")
391-
)
392383

393384
# other params include max_depth, min_samples_split, min_samples_leaf
394385
rf_keys = {"n_estimators"}
395-
396386
self.rf_params = {k: v for k, v in params.items() if k in rf_keys}
397387

398-
self.data = self._preprocess(
399-
self.raw_data,
400-
self.months,
401-
self.end_month_index,
402-
self.date_column,
403-
)
388+
data_t = self.raw_data.with_columns(
389+
t=pl.col(self.date_column).map_elements(self._month_in_season)
390+
).sort(["season", "geography", "t"])
404391

405-
@classmethod
406-
def _preprocess(
407-
cls, data: pl.DataFrame, months, end_month_index, date_column
408-
) -> pl.DataFrame:
409-
out = (
410-
data.with_columns(
411-
t=pl.col(date_column)
412-
.dt.to_string("%b")
413-
.map_elements(lambda x: months.index(x) - end_month_index, pl.Int64)
414-
)
415-
.filter(pl.col("t").is_between(1 - end_month_index, 0))
416-
.select(["season", "geography", "t", "estimate"])
417-
.with_columns(pl.format("t={}", pl.col("t")))
418-
.pivot(on="t", values="estimate")
392+
# preprocessing
393+
self.date_crosswalk = data_t.select("season", date_column, "t").unique()
394+
395+
self.data = (
396+
data_t.select(["season", "geography", "t", "estimate"])
397+
.pivot(on="t", values="estimate", sort_columns=True)
398+
# impute zero uptake at start of season
399+
.with_columns(pl.coalesce(pl.col("0"), 0.0))
400+
# drop season/geo's with any other missing values
419401
.drop_nulls()
420402
.sort(["season", "geography"])
421403
)
422404

423-
return out
424-
425-
def fit(self) -> Self:
426-
self.enc = CoverageEncoder()
427-
self.enc.fit(self.data)
428-
429-
target_season = iup.date_to_season(
430-
pl.lit(self.forecast_date),
431-
season_start_month=self.season["start_month"],
432-
season_start_day=self.season["start_day"],
433-
)
434-
435-
forecast_t = (
436-
self.months.index(self.forecast_date.strftime("%b")) - self.end_month_index
437-
)
405+
self.forecast_season = pl.select(
406+
iup.to_season(
407+
pl.lit(self.forecast_date),
408+
season_start_month=self.season["start_month"],
409+
season_end_month=self.season["end_month"],
410+
season_end_day=self.season["end_day"],
411+
season_start_day=self.season["start_day"],
412+
)
413+
).item()
414+
self.forecast_month = self._month_in_season(self.forecast_date)
415+
416+
def _month_in_season(self, date: datetime.date) -> int:
417+
assert date.day == 1
418+
year = date.year
419+
# start of a season that's in this year
420+
ssiy = datetime.date(year, self.season["start_month"], self.season["start_day"])
421+
422+
# season start year
423+
if date < ssiy:
424+
ssy = year - 1
425+
else:
426+
ssy = year
438427

439-
end_date = datetime.date(
440-
self.season["end_year"], self.season["end_month"], self.season["end_day"]
441-
)
428+
return (year - ssy) * 12 + (date.month - self.season["start_month"])
442429

443-
# this is true only when target_season is the last season in the data, which is our case for now
444-
assert self.data.select(target_season).item() == self.data["season"].max()
445-
data_fit = self.data.filter(pl.col("season") != target_season)
430+
def fit(self) -> Self:
431+
self.enc = Encoder().fit(self.data)
446432

447-
# fit all the data after forecast_t
448-
features = ["season", "geography"] + [
449-
f"t={t}"
450-
for t in range(
451-
1 - self.months.index(end_date.strftime("%b")), forecast_t + 1
452-
)
433+
self.X_features = ["season", "geography"] + [
434+
str(t)
435+
for t in range(0, self.forecast_month + 1)
436+
if str(t) in self.data.columns
453437
]
438+
self.y_features = [
439+
str(t)
440+
for t in range(self.forecast_month + 1, 12)
441+
if str(t) in self.data.columns
442+
]
443+
444+
# fit the model
445+
data_fit = self.data.filter(pl.col("season") < self.forecast_season)
446+
X_fit = self.enc.encode(data_fit.select(self.X_features))
447+
y_fit = data_fit.select(self.y_features).to_numpy()
454448

455-
X_fit = self.enc.encode(data_fit.select(features))
456-
y_fit = data_fit.select(
457-
[f"t={target_t}" for target_t in range(forecast_t + 1, 1)]
458-
).to_numpy()
449+
# sklearn complains if you pass a column vector rather than a 1d array
450+
if y_fit.shape[1] == 1:
451+
y_fit = y_fit.ravel()
459452

460453
self.model = RandomForestRegressor(**self.rf_params).fit(X_fit, y_fit)
461454

462455
return self
463456

464457
def predict(self) -> pl.DataFrame:
465-
assert self.model is not None
466-
467-
# include in-sample and out-of-sample prediction
468-
data_pred = self.data
469-
470-
forecast_t = (
471-
self.months.index(self.forecast_date.strftime("%b")) - self.end_month_index
472-
)
473-
474-
end_date = datetime.date(
475-
self.season["end_year"], self.season["end_month"], self.season["end_day"]
476-
)
477-
478-
features = ["season", "geography"] + [
479-
f"t={t}"
480-
for t in range(
481-
1 - self.months.index(end_date.strftime("%b")), forecast_t + 1
482-
)
483-
]
484-
485-
X_pred = self.enc.encode(data_pred.select(features))
486-
t_cols = [f"t={t}" for t in range(forecast_t + 1, 1)]
487-
index_cols = ["season", "geography", "quantile"]
488-
489-
pred = np.array([tree.predict(X_pred) for tree in self.model.estimators_])
490-
pred = {f"q={k}": np.quantile(pred, k, axis=0) for k in self.quantiles}
491-
all_pred = pl.DataFrame()
492-
493-
for k, v in pred.items():
494-
df = pl.DataFrame(v, schema=[f"t={t}" for t in range(forecast_t + 1, 1)])
495-
df = df.with_columns(
496-
quantile=pl.lit(k).str.replace("q=", "").cast(pl.Float64)
497-
)
498-
499-
pred_df = pl.concat(
500-
[data_pred.select(["season", "geography"]), df], how="horizontal"
458+
# make the forecast
459+
data_pred = self.data.filter(pl.col("season") >= self.forecast_season)
460+
461+
X_data = data_pred.select(self.X_features)
462+
assert X_data.shape[0] > 0, f"RF prediction for {self.forecast_date} failed"
463+
X_pred = self.enc.encode(X_data)
464+
465+
# make predictions using each tree
466+
y_tree = np.stack([tree.predict(X_pred) for tree in self.model.estimators_])
467+
468+
return iup.QuantileForecast(
469+
pl.concat(
470+
[
471+
self._postprocess(
472+
data_pred=data_pred,
473+
y_pred=np.quantile(y_tree, q=q, axis=0),
474+
quantile=q,
475+
)
476+
for q in self.quantiles
477+
]
501478
)
479+
)
502480

503-
all_pred = pl.concat([all_pred, pred_df])
481+
def _postprocess(
482+
self, data_pred: pl.DataFrame, y_pred: np.ndarray, quantile: float
483+
) -> pl.DataFrame:
484+
if len(y_pred.shape) == 1:
485+
y_pred = y_pred.reshape(-1, 1)
504486

505-
all_pred = (
506-
all_pred.unpivot(
507-
on=t_cols,
508-
index=index_cols,
509-
variable_name="target_t",
487+
return (
488+
data_pred.select(["season", "geography"])
489+
.hstack(pl.DataFrame(y_pred, schema=self.y_features))
490+
.unpivot(
491+
on=self.y_features,
492+
index=["season", "geography"],
493+
variable_name="t",
510494
value_name="estimate",
511495
)
512-
.with_columns(
513-
forecast_date=self.forecast_date,
514-
target_index=(
515-
pl.col("target_t").str.replace("t=", "").cast(pl.Int8)
516-
+ self.end_month_index
517-
), # convert back to month index
518-
target_year=pl.col("season").str.extract(r"^(\d{4})/\d{4}"),
519-
)
520-
.with_columns(
521-
season_start_date=pl.date(
522-
pl.col("target_year"),
523-
self.season["start_month"],
524-
self.season["start_day"],
525-
),
526-
target_index=pl.format("{}mo", pl.col("target_index")),
527-
)
528-
.with_columns(
529-
pl.col("season_start_date")
530-
.dt.offset_by(pl.col("target_index"))
531-
.alias("time_end")
532-
)
533-
.drop(["target_index", "target_year", "season_start_date", "target_t"])
496+
.with_columns(pl.col("t").cast(pl.Int64))
497+
.join(self.date_crosswalk, on=["season", "t"], how="left")
498+
.drop("t")
499+
.with_columns(forecast_date=self.forecast_date, quantile=quantile)
534500
)
535501

536-
return all_pred
537-
538-
@staticmethod
539-
def _month_order(season_start_month: int) -> List[str]:
540-
return [
541-
calendar.month_abbr[i]
542-
for i in list(range(season_start_month, 12 + 1))
543-
+ list(range(1, season_start_month))
544-
]
545502

546-
547-
class CoverageEncoder:
548-
def __init__(self, categorical_feature_names: tuple = ("season", "geography")):
549-
self.categorical_feature_names = categorical_feature_names
503+
class Encoder:
504+
def __init__(self, categorical_features: tuple = ("season", "geography")):
505+
self.categorical_features = categorical_features
550506
self.enc = OneHotEncoder(sparse_output=False)
551-
self.categorical_features = None
552-
553-
def fit(self, data: pl.DataFrame):
554-
self.enc.fit(data.select(self.categorical_feature_names).to_numpy())
555507

556-
self.categorical_features = list(
557-
self._iter_features(self.categorical_feature_names, self.enc.categories_)
558-
)
559-
560-
@staticmethod
561-
def _iter_features(names, categories):
562-
for feature, values in zip(names, categories):
563-
for value in values:
564-
yield (feature, value)
508+
def fit(self, data: pl.DataFrame) -> Self:
509+
self.enc.fit(data.select(self.categorical_features).to_numpy())
510+
return self
565511

566512
def encode(self, data: pl.DataFrame) -> np.ndarray:
567-
X_enc = self.enc.transform(
568-
data.select(self.categorical_feature_names).to_numpy()
569-
)
570-
X_pass = data.drop(self.categorical_feature_names).to_numpy()
513+
X_enc = self.enc.transform(data.select(self.categorical_features).to_numpy())
514+
X_pass = data.drop(self.categorical_features).to_numpy()
571515

572516
assert isinstance(X_enc, np.ndarray)
573517
return np.asarray(np.hstack((X_enc, X_pass)))
574-
575-
def categories(self, data: pl.DataFrame):
576-
if self.categorical_features is None:
577-
raise RuntimeError
578-
else:
579-
return self.categorical_features + [
580-
("unencoded", col)
581-
for col in data.drop(self.categorical_feature_names).columns
582-
]

0 commit comments

Comments
 (0)