Skip to content

Commit 1812ba7

Browse files
committed
Jettison data classes
1 parent dd6860c commit 1812ba7

6 files changed

Lines changed: 18 additions & 407 deletions

File tree

iup/__init__.py

Lines changed: 0 additions & 159 deletions
Original file line number | Diff line number | Diff line change
@@ -1,163 +1,4 @@
1-
from typing import List
2-
31
import polars as pl
4-
from polars.datatypes.classes import DataTypeClass
5-
6-
7-
class Data(pl.DataFrame):
    """
    Abstract class for observed data and forecast data.

    Behaves like a `pl.DataFrame`, except that `validate()` is called at the
    end of construction. Subclasses must override `validate()`.
    """

    def __init__(self, *args, **kwargs):
        """Build the data frame, then validate it.

        All positional and keyword arguments are forwarded unchanged to
        `pl.DataFrame.__init__`.
        """
        super().__init__(*args, **kwargs)
        self.validate()

    def validate(self):
        """Check the contents of the data frame.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def assert_in_schema(self, names_types: dict[str, DataTypeClass]):
        """Verify that columns of the expected types are present in the data frame.

        Args:
            names_types: Column names and types mapping.

        Raises:
            RuntimeError: If a column is missing, or is present with a
                different type.
        """
        schema = self.schema
        for name, type_ in names_types.items():
            if name not in schema.names():
                raise RuntimeError(f"Column '{name}' not found")

            # Report the polars dtype that is actually compared, rather than
            # its Python equivalent, so the message cannot name a type that
            # looks identical to the expected one.
            actual_type = schema[name]
            if actual_type != type_:
                raise RuntimeError(
                    f"Column '{name}' has type {actual_type}, not {type_}"
                )
37-
38-
39-
class CoverageData(Data):
    """Data frame that carries a `time_end` date and an `estimate` value."""

    def validate(self):
        """Require `time_end` (Date) and `estimate` (Float64); extra columns are allowed."""
        required_columns = {"time_end": pl.Date, "estimate": pl.Float64}
        self.assert_in_schema(required_columns)
43-
44-
45-
class IncidentCoverageData(CoverageData):
    """Coverage data whose `estimate` column holds incident coverage.

    The per-group cumulative sum of `estimate` is the cumulative coverage
    (see `to_cumulative`).
    """

    def validate(self):
        """Check the coverage schema and the range of `estimate`.

        Raises:
            RuntimeError: If `time_end` or `estimate` is missing or has the
                wrong type.
            ValueError: If any `estimate` is not between -1.0 and 1.0.
        """
        super().validate()
        if not self["estimate"].is_between(-1.0, 1.0).all():
            bad_values = (
                self.filter(pl.col("estimate").is_between(-1.0, 1.0).not_())["estimate"]
                .unique()
                .to_list()
            )
            raise ValueError(
                "Incident coverage `estimate` must have values between -1 and +1. "
                f"Values included {bad_values}"
            )

    def to_cumulative(
        self, groups: List[str] | None, prev_cumulative: pl.DataFrame | None = None
    ) -> "CumulativeCoverageData":
        """Convert incident to cumulative coverage data.

        Cumulative sum of incident coverage gives the cumulative coverage.
        Optionally, additional cumulative coverage from before the start of
        the incident data may be provided.
        Even if no groups are specified, the data must at least be grouped by season.

        Args:
            groups: Names of the columns of grouping factors, or None. If `None`, then
                data will be grouped by `"season"`.
            prev_cumulative: Cumulative coverage from before the start of the incident
                data, for each group, or None. If given, it must contain the grouping
                columns and a `"last_cumulative"` column, which is added to the
                cumulative sum and then dropped. If `None`, nothing is added.

        Returns:
            Cumulative coverage on each date in the input incident coverage data.
        """
        if groups is None:
            groups = ["season"]

        # NOTE(review): the running sum follows the frame's current row order
        # within each group; presumably callers pass rows sorted by date - confirm.
        out = self.with_columns(estimate=pl.col("estimate").cum_sum().over(groups))

        if prev_cumulative is not None:
            # NOTE(review): the join uses polars' default join type, so groups
            # missing from `prev_cumulative` may be dropped - confirm this is intended.
            out = out.join(prev_cumulative, on=groups)

            out = out.with_columns(
                estimate=pl.col("estimate") + pl.col("last_cumulative")
            ).drop("last_cumulative")

        return CumulativeCoverageData(out)
91-
92-
93-
class CumulativeCoverageData(CoverageData):
    """Coverage data whose `estimate` column holds cumulative coverage.

    Every `estimate` must be a proportion between 0.0 and 1.0.
    """

    def validate(self):
        """Check the coverage schema and that `estimate` is a proportion.

        Raises:
            RuntimeError: If `time_end` or `estimate` is missing or has the
                wrong type.
            AssertionError: If any `estimate` is not between 0.0 and 1.0.
        """
        super().validate()
        if not self["estimate"].is_between(0.0, 1.0).all():
            # Raised explicitly rather than with `assert`, so the check still
            # runs under `python -O`; the exception type is unchanged.
            raise AssertionError("Cumulative coverage `estimate` must be a proportion")

    def to_incident(self, groups: List[str] | None) -> IncidentCoverageData:
        """Convert cumulative to incident coverage data.

        Because the first report date for each group is often rollout,
        incident coverage on the first report date is 0.

        Args:
            groups: Names of the columns of grouping factors, or None. If `None`,
                then data will be grouped by `"season"`.

        Returns:
            Incident coverage on each date in the input cumulative coverage data.
        """
        if groups is None:
            groups = ["season"]

        # `diff()` leaves a null on the first row of each group; `fill_null(0)`
        # turns it into the zero incident coverage described above.
        out = self.with_columns(
            estimate=pl.col("estimate").diff().over(groups).fill_null(0)
        )

        return IncidentCoverageData(out)
121-
122-
123-
class QuantileForecast(Data):
    """
    Class for forecast with quantiles.
    Save for future.
    """

    def validate(self):
        """Check the required columns and the range of `quantile`.

        Requires `time_end` (Date), `quantile` (Float64) and `estimate`
        (Float64) columns.

        Raises:
            RuntimeError: If a required column is missing or has the wrong type.
            AssertionError: If any `quantile` is not between 0.0 and 1.0.
        """
        self.assert_in_schema(
            {"time_end": pl.Date, "quantile": pl.Float64, "estimate": pl.Float64}
        )

        if not self["quantile"].is_between(0.0, 1.0).all():
            # Raised explicitly rather than with `assert`, so the check still
            # runs under `python -O`; the exception type is unchanged.
            raise AssertionError("quantiles must be between 0 and 1")
137-
138-
139-
class PointForecast(QuantileForecast):
    """
    Class for forecast with point estimate
    A subclass when quantile is 50%
    For now, enforce the "quantile50" to be "estimate"
    """

    def validate(self):
        """Check the quantile-forecast rules and that every `quantile` is 0.50.

        Raises:
            RuntimeError: If a required column is missing or has the wrong type.
            AssertionError: If any `quantile` is outside 0.0-1.0 or differs
                from 0.50.
        """
        super().validate()
        if not (self["quantile"] == 0.50).all():
            # Raised explicitly rather than with `assert`, so the check still
            # runs under `python -O`; the exception type is unchanged.
            raise AssertionError("Point forecast `quantile` must be 0.50")
149-
150-
151-
class SampleForecast(Data):
    """
    Forecast expressed as samples from a posterior distribution.
    Reserved for future use.
    """

    def validate(self):
        """Require `time_end` (Date), `sample_id` (UInt64) and `estimate` (Float64)."""
        required_columns = {
            "time_end": pl.Date,
            "sample_id": pl.UInt64,
            "estimate": pl.Float64,
        }
        self.assert_in_schema(required_columns)
1612

1623

1634
def to_season(

iup/models.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ class LPLModel(CoverageModel):
5151

5252
def __init__(
5353
self,
54-
data: iup.CumulativeCoverageData,
54+
data: pl.DataFrame,
5555
forecast_date: datetime.date,
5656
params: dict[str, Any],
5757
season: dict[str, Any],
@@ -361,13 +361,13 @@ def predict(self) -> pl.DataFrame:
361361
)
362362
)
363363

364-
return iup.QuantileForecast(data_pred.explode(["quantile", "estimate"]))
364+
return data_pred.explode(["quantile", "estimate"])
365365

366366

367367
class RFModel(CoverageModel):
368368
def __init__(
369369
self,
370-
data: iup.CumulativeCoverageData,
370+
data: pl.DataFrame,
371371
params: dict[str, Any],
372372
season: dict[str, Any],
373373
forecast_date: datetime.date,
@@ -465,17 +465,15 @@ def predict(self) -> pl.DataFrame:
465465
# make predictions using each tree
466466
y_tree = np.stack([tree.predict(X_pred) for tree in self.model.estimators_])
467467

468-
return iup.QuantileForecast(
469-
pl.concat(
470-
[
471-
self._postprocess(
472-
data_pred=data_pred,
473-
y_pred=np.quantile(y_tree, q=q, axis=0),
474-
quantile=q,
475-
)
476-
for q in self.quantiles
477-
]
478-
)
468+
return pl.concat(
469+
[
470+
self._postprocess(
471+
data_pred=data_pred,
472+
y_pred=np.quantile(y_tree, q=q, axis=0),
473+
quantile=q,
474+
)
475+
for q in self.quantiles
476+
]
479477
)
480478

481479
def _postprocess(

scripts/preprocess.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,11 @@
55
import polars as pl
66
import yaml
77

8-
from iup import CumulativeCoverageData, to_season
8+
from iup import to_season
99

1010

1111
def preprocess(
12-
raw_data: pl.LazyFrame,
12+
raw_data: pl.DataFrame,
1313
start_year: int,
1414
end_year: int,
1515
season_start_month: int,
@@ -18,7 +18,7 @@ def preprocess(
1818
season_end_day: int,
1919
geographies: Optional[List[str] | None],
2020
date_col: str = "time_end",
21-
) -> CumulativeCoverageData:
21+
) -> pl.DataFrame:
2222
"""
2323
Preprocess the raw data (Filter the raw data with certain states and seasons, add season column).
2424
@@ -37,13 +37,13 @@ def preprocess(
3737
3838
"""
3939

40-
def geo_filter(df: pl.LazyFrame) -> pl.LazyFrame:
40+
def geo_filter(df: pl.DataFrame) -> pl.DataFrame:
4141
if geographies is None:
4242
return df
4343
else:
4444
return df.filter(pl.col("geography").is_in(geographies))
4545

46-
data = (
46+
return (
4747
raw_data.filter(
4848
pl.col("geography_type") == pl.lit("admin1"),
4949
pl.col("geography")
@@ -66,11 +66,8 @@ def geo_filter(df: pl.LazyFrame) -> pl.LazyFrame:
6666
pl.col("season").is_null().not_(),
6767
)
6868
.pipe(geo_filter)
69-
.collect()
7069
)
7170

72-
return CumulativeCoverageData(data)
73-
7471

7572
if __name__ == "__main__":
7673
p = argparse.ArgumentParser()
@@ -82,7 +79,7 @@ def geo_filter(df: pl.LazyFrame) -> pl.LazyFrame:
8279
with open(args.config) as f:
8380
config = yaml.safe_load(f)
8481

85-
raw_data = pl.scan_parquet(args.input)
82+
raw_data = pl.read_parquet(args.input)
8683

8784
assert isinstance(config, dict)
8885
geographies = config.get("geographies", None)

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import polars as pl
22
import pytest
33

4-
import iup
5-
64

75
@pytest.fixture
86
def frame():
@@ -80,6 +78,4 @@ def frame():
8078
schema_overrides={"time_end": pl.Date},
8179
)
8280

83-
frame = iup.CumulativeCoverageData(frame)
84-
8581
return frame

0 commit comments

Comments
 (0)