|
4 | 4 | os.environ["JAX_PLATFORMS"] = "cpu" |
5 | 5 |
|
6 | 6 | import abc |
7 | | -import calendar |
8 | 7 | import datetime |
9 | 8 | import inspect |
10 | | -from typing import Any, List |
| 9 | +from typing import Any |
11 | 10 |
|
12 | 11 | import jax.numpy as jnp |
13 | 12 | import numpy as np |
@@ -381,202 +380,138 @@ def __init__( |
381 | 380 | self.quantiles = quantiles |
382 | 381 | self.season = season |
383 | 382 | self.params = params |
384 | | - self.months = self._month_order(self.season["start_month"]) |
385 | | - self.end_month_index = self.months.index( |
386 | | - datetime.date( |
387 | | - self.season["end_year"], |
388 | | - self.season["end_month"], |
389 | | - self.season["end_day"], |
390 | | - ).strftime("%b") |
391 | | - ) |
392 | 383 |
|
393 | 384 | # other params include max_depth, min_samples_split, min_samples_leaf |
394 | 385 | rf_keys = {"n_estimators"} |
395 | | - |
396 | 386 | self.rf_params = {k: v for k, v in params.items() if k in rf_keys} |
397 | 387 |
|
398 | | - self.data = self._preprocess( |
399 | | - self.raw_data, |
400 | | - self.months, |
401 | | - self.end_month_index, |
402 | | - self.date_column, |
403 | | - ) |
| 388 | + data_t = self.raw_data.with_columns( |
| 389 | + t=pl.col(self.date_column).map_elements(self._month_in_season) |
| 390 | + ).sort(["season", "geography", "t"]) |
404 | 391 |
|
405 | | - @classmethod |
406 | | - def _preprocess( |
407 | | - cls, data: pl.DataFrame, months, end_month_index, date_column |
408 | | - ) -> pl.DataFrame: |
409 | | - out = ( |
410 | | - data.with_columns( |
411 | | - t=pl.col(date_column) |
412 | | - .dt.to_string("%b") |
413 | | - .map_elements(lambda x: months.index(x) - end_month_index, pl.Int64) |
414 | | - ) |
415 | | - .filter(pl.col("t").is_between(1 - end_month_index, 0)) |
416 | | - .select(["season", "geography", "t", "estimate"]) |
417 | | - .with_columns(pl.format("t={}", pl.col("t"))) |
418 | | - .pivot(on="t", values="estimate") |
| 392 | + # preprocessing |
| 393 | + self.date_crosswalk = data_t.select("season", date_column, "t").unique() |
| 394 | + |
| 395 | + self.data = ( |
| 396 | + data_t.select(["season", "geography", "t", "estimate"]) |
| 397 | + .pivot(on="t", values="estimate", sort_columns=True) |
| 398 | + # impute zero uptake at start of season |
| 399 | + .with_columns(pl.coalesce(pl.col("0"), 0.0)) |
| 400 | + # drop season/geo's with any other missing values |
419 | 401 | .drop_nulls() |
420 | 402 | .sort(["season", "geography"]) |
421 | 403 | ) |
422 | 404 |
|
423 | | - return out |
424 | | - |
425 | | - def fit(self) -> Self: |
426 | | - self.enc = CoverageEncoder() |
427 | | - self.enc.fit(self.data) |
428 | | - |
429 | | - target_season = iup.date_to_season( |
430 | | - pl.lit(self.forecast_date), |
431 | | - season_start_month=self.season["start_month"], |
432 | | - season_start_day=self.season["start_day"], |
433 | | - ) |
434 | | - |
435 | | - forecast_t = ( |
436 | | - self.months.index(self.forecast_date.strftime("%b")) - self.end_month_index |
437 | | - ) |
| 405 | + self.forecast_season = pl.select( |
| 406 | + iup.to_season( |
| 407 | + pl.lit(self.forecast_date), |
| 408 | + season_start_month=self.season["start_month"], |
| 409 | + season_end_month=self.season["end_month"], |
| 410 | + season_end_day=self.season["end_day"], |
| 411 | + season_start_day=self.season["start_day"], |
| 412 | + ) |
| 413 | + ).item() |
| 414 | + self.forecast_month = self._month_in_season(self.forecast_date) |
| 415 | + |
| 416 | + def _month_in_season(self, date: datetime.date) -> int: |
| 417 | + assert date.day == 1 |
| 418 | + year = date.year |
| 419 | + # start of a season that's in this year |
| 420 | + ssiy = datetime.date(year, self.season["start_month"], self.season["start_day"]) |
| 421 | + |
| 422 | + # season start year |
| 423 | + if date < ssiy: |
| 424 | + ssy = year - 1 |
| 425 | + else: |
| 426 | + ssy = year |
438 | 427 |
|
439 | | - end_date = datetime.date( |
440 | | - self.season["end_year"], self.season["end_month"], self.season["end_day"] |
441 | | - ) |
| 428 | + return (year - ssy) * 12 + (date.month - self.season["start_month"]) |
442 | 429 |
|
443 | | - # this is true only when target_season is the last season in the data, which is our case for now |
444 | | - assert self.data.select(target_season).item() == self.data["season"].max() |
445 | | - data_fit = self.data.filter(pl.col("season") != target_season) |
def fit(self) -> Self:
    """Fit the categorical encoder and the random forest.

    Sets ``self.enc``, ``self.X_features``, ``self.y_features`` and
    ``self.model``.

    Predictors are ``season``, ``geography`` and every month column from
    ``"0"`` through ``str(self.forecast_month)`` that exists in
    ``self.data``. Targets are the month columns after the forecast month,
    up to ``"11"``, that exist in ``self.data``. Only rows whose season
    compares less than ``self.forecast_season`` are used for training.

    Returns:
        This object, to allow call chaining.
    """
    self.enc = Encoder().fit(self.data)

    # month columns are named by their stringified month offset
    present = set(self.data.columns)
    observed_months = [str(t) for t in range(0, self.forecast_month + 1)]
    future_months = [str(t) for t in range(self.forecast_month + 1, 12)]

    self.X_features = ["season", "geography"] + [
        name for name in observed_months if name in present
    ]
    self.y_features = [name for name in future_months if name in present]

    # train on seasons before the forecast season
    training = self.data.filter(pl.col("season") < self.forecast_season)
    X_fit = self.enc.encode(training.select(self.X_features))
    y_fit = training.select(self.y_features).to_numpy()

    # sklearn complains if you pass a column vector rather than a 1d array
    if y_fit.shape[1] == 1:
        y_fit = y_fit.ravel()

    forest = RandomForestRegressor(**self.rf_params)
    self.model = forest.fit(X_fit, y_fit)

    return self
463 | 456 |
|
def predict(self) -> pl.DataFrame:
    """Produce quantile forecasts from the fitted forest.

    Rows of ``self.data`` whose season is at or after
    ``self.forecast_season`` are encoded and passed to every tree in
    ``self.model.estimators_``. For each value in ``self.quantiles`` the
    quantile across trees is taken and handed to ``_postprocess``. The
    per-quantile frames are concatenated and wrapped in
    ``iup.QuantileForecast``.

    Raises:
        AssertionError: If no rows are available to predict on.
    """
    # rows to forecast: the forecast season (and anything later)
    data_pred = self.data.filter(pl.col("season") >= self.forecast_season)

    X_data = data_pred.select(self.X_features)
    assert X_data.shape[0] > 0, f"RF prediction for {self.forecast_date} failed"
    X_pred = self.enc.encode(X_data)

    # one prediction array per tree, stacked along a new leading axis
    per_tree = [tree.predict(X_pred) for tree in self.model.estimators_]
    y_tree = np.stack(per_tree)

    # the quantile is taken across trees (axis 0)
    quantile_frames = [
        self._postprocess(
            data_pred=data_pred,
            y_pred=np.quantile(y_tree, q=q, axis=0),
            quantile=q,
        )
        for q in self.quantiles
    ]

    return iup.QuantileForecast(pl.concat(quantile_frames))
502 | 480 |
|
503 | | - all_pred = pl.concat([all_pred, pred_df]) |
def _postprocess(
    self, data_pred: pl.DataFrame, y_pred: np.ndarray, quantile: float
) -> pl.DataFrame:
    """Turn one quantile's predictions into a long-format frame.

    Args:
        data_pred: Frame supplying the ``season`` and ``geography`` columns
            for the predicted rows.
        y_pred: Predictions with one column per entry of
            ``self.y_features``. A 1-d array is treated as a single column.
        quantile: Quantile level, stored in the ``quantile`` column.

    Returns:
        Frame with one row per season, geography and target month. It holds
        ``estimate``, the columns joined in from ``self.date_crosswalk``,
        ``forecast_date`` and ``quantile``. The helper column ``t`` is
        dropped.
    """
    # a single target comes back 1-d; make it a one-column matrix
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    keys = ["season", "geography"]
    predictions = pl.DataFrame(y_pred, schema=self.y_features)
    wide = data_pred.select(["season", "geography"]).hstack(predictions)

    long = wide.unpivot(
        on=self.y_features,
        index=keys,
        variable_name="t",
        value_name="estimate",
    )

    # column names were stringified month offsets; restore integers to join
    long = long.with_columns(pl.col("t").cast(pl.Int64))
    dated = long.join(self.date_crosswalk, on=["season", "t"], how="left")

    return dated.drop("t").with_columns(
        forecast_date=self.forecast_date, quantile=quantile
    )
535 | 501 |
|
536 | | - return all_pred |
537 | | - |
538 | | - @staticmethod |
539 | | - def _month_order(season_start_month: int) -> List[str]: |
540 | | - return [ |
541 | | - calendar.month_abbr[i] |
542 | | - for i in list(range(season_start_month, 12 + 1)) |
543 | | - + list(range(1, season_start_month)) |
544 | | - ] |
545 | 502 |
|
546 | | - |
547 | | -class CoverageEncoder: |
548 | | - def __init__(self, categorical_feature_names: tuple = ("season", "geography")): |
549 | | - self.categorical_feature_names = categorical_feature_names |
class Encoder:
    """One-hot encode selected columns and pass the rest through unchanged."""

    def __init__(self, categorical_features: tuple = ("season", "geography")):
        """Store the categorical column names and create the one-hot encoder.

        Args:
            categorical_features: Names of the columns to one-hot encode.
        """
        self.enc = OneHotEncoder(sparse_output=False)
        self.categorical_features = categorical_features

    def fit(self, data: pl.DataFrame) -> Self:
        """Fit the one-hot encoder on the categorical columns of `data`.

        Returns:
            This encoder, to allow call chaining.
        """
        categorical_values = data.select(self.categorical_features).to_numpy()
        self.enc.fit(categorical_values)
        return self

    def encode(self, data: pl.DataFrame) -> np.ndarray:
        """Encode `data` as a single numpy array.

        The one-hot encoded categorical columns come first, followed by all
        remaining columns of `data` as they are.
        """
        categorical_values = data.select(self.categorical_features).to_numpy()
        one_hot = self.enc.transform(categorical_values)
        passthrough = data.drop(self.categorical_features).to_numpy()

        assert isinstance(one_hot, np.ndarray)
        combined = np.hstack((one_hot, passthrough))
        return np.asarray(combined)
574 | | - |
575 | | - def categories(self, data: pl.DataFrame): |
576 | | - if self.categorical_features is None: |
577 | | - raise RuntimeError |
578 | | - else: |
579 | | - return self.categorical_features + [ |
580 | | - ("unencoded", col) |
581 | | - for col in data.drop(self.categorical_feature_names).columns |
582 | | - ] |
0 commit comments