diff --git a/pyproject.toml b/pyproject.toml index 5e4ec54..38dd5d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ classifiers = [ dynamic = ["version"] dependencies = [ - "iddata @ git+https://github.com/reichlab/iddata", + "iddata @ git+https://github.com/reichlab/iddata@7b86ad0e513423faa8d327426c18350dfdfa07f0", "lightgbm", "numpy", "pandas~=2.0", # pandas 3.0 breaks compatibility; remove cap once validated diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 7b2b7cc..d88e116 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -38,7 +38,7 @@ frozenlist==1.5.0 # aiosignal fsspec==2024.10.0 # via s3fs -iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8 +iddata @ git+https://github.com/reichlab/iddata@7b86ad0e513423faa8d327426c18350dfdfa07f0 # via idmodels (pyproject.toml) identify==2.6.1 # via pre-commit diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7615a50..7e1ddda 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -30,7 +30,7 @@ frozenlist==1.5.0 # aiosignal fsspec==2024.10.0 # via s3fs -iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8 +iddata @ git+https://github.com/reichlab/iddata@7b86ad0e513423faa8d327426c18350dfdfa07f0 # via idmodels (pyproject.toml) idna==3.10 # via yarl diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000..0a1a704 --- /dev/null +++ b/src/app.py @@ -0,0 +1,26 @@ +import pickle +from pathlib import Path + +from idmodels.sarix import SARIXModel + + +def main(): + pkl_dir = Path("/Users/cornell/IdeaProjects/operational-models/covid_gbqr/") + # pkl_dir = Path('/Users/cornell/IdeaProjects/operational-models/covid_ar6_pooled/') + pkl_mc_rc_pairs = [ + # ('2025-08-02-model_config.pkl', '2025-08-02-run_config.pkl'), # works -> 2025-08-02-UMass-gbqr.csv + ("2025-08-09-model_config.pkl", "2025-08-09-run_config.pkl"), # fails -> 2025-08-09-UMass-gbqr.csv + ] + for model_config_file_name, run_config_file_name in pkl_mc_rc_pairs: + with open(pkl_dir / model_config_file_name, "rb") as mc_fp, open(pkl_dir / run_config_file_name, "rb") as rc_fp: + model_config = pickle.load(mc_fp) + run_config = pickle.load(rc_fp) + print("*", model_config_file_name, ",", run_config_file_name) + print("yy", model_config) + model_config.x = [] + model = SARIXModel(model_config) + model.run(run_config) + + +if __name__ == "__main__": + main() diff --git a/src/idmodels/__init__.py b/src/idmodels/__init__.py index 3fb0078..5ba7419 100644 --- a/src/idmodels/__init__.py +++ b/src/idmodels/__init__.py @@ -1,17 +1,28 @@ +from iddata.enums import Disease, SourceType + from idmodels.config import ( - DataSource, - Disease, - GBQRModelConfig, - PoolingStrategy, - PowerTransform, - RunConfig, - SARIXFourierModelConfig, - SARIXModelConfig, + GBQRModelConfig, + PoolingStrategy, + PowerTransform, + RunConfig, + SARIXFourierModelConfig, + SARIXModelConfig, ) from idmodels.gbqr import GBQRModel from idmodels.sarix import SARIXFourierModel, SARIXModel -__all__ = ["DataSource", "Disease", "GBQRModel", "GBQRModelConfig", "PoolingStrategy", "PowerTransform", "RunConfig", - "SARIXFourierModel", "SARIXFourierModelConfig", "SARIXModel", "SARIXModelConfig"] +__all__ = [ + "Disease", + "GBQRModel", + "GBQRModelConfig", + "PoolingStrategy", + "PowerTransform", + "RunConfig", + "SARIXFourierModel", + "SARIXFourierModelConfig", + "SARIXModel", + "SARIXModelConfig", + 
"SourceType", +] __version__ = "1.3.1" diff --git a/src/idmodels/config.py b/src/idmodels/config.py index f4574b6..26278c5 100644 --- a/src/idmodels/config.py +++ b/src/idmodels/config.py @@ -4,18 +4,10 @@ from enum import Enum from pathlib import Path - -class DataSource(str, Enum): - NHSN = "nhsn" - NSSP = "nssp" - FLUSURVNET = "flusurvnet" - ILINET = "ilinet" - - -class Disease(str, Enum): - FLU = "flu" - COVID = "covid" - RSV = "rsv" +from iddata.enums import ( + Disease, # used internally for RunConfig.disease; import from iddata.enums directly in callers + SourceType, # re-exported for callers: from idmodels.config import SourceType +) class PowerTransform(str, Enum): @@ -30,18 +22,14 @@ class PoolingStrategy(str, Enum): @dataclass class ModelConfig(ABC): - """ - Abstract base for model configuration. - - Holds settings that describe *what* model to run and how it processes data (sources, transforms, pooling). - Not instantiated directly - use :class:`SARIXModelConfig` or :class:`GBQRModelConfig`. - """ + """Abstract base for model configuration.""" model_name: str - sources: list[DataSource] + sources: list[SourceType] fit_locations_separately: bool power_transform: PowerTransform + def __post_init__(self): if type(self) is ModelConfig: raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig") @@ -49,11 +37,7 @@ def __post_init__(self): @dataclass class RunConfig: - """ - Run configuration. - - Holds settings that describe a single execution: which disease, which locations, output paths, quantile levels, etc. - """ + """Run configuration: disease, locations, output paths, quantile levels.""" disease: Disease ref_date: datetime.date @@ -95,7 +79,6 @@ class GBQRModelConfig(ModelConfig): reporting_adj: bool = False save_feat_importance: bool = False - # directional wave features (disabled by default) use_directional_waves: bool = False wave_directions: list[str] = field(default_factory=lambda: ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]) wave_temporal_lags: list[int] = field(default_factory=lambda: [1, 2]) diff --git a/src/idmodels/constants.py b/src/idmodels/constants.py new file mode 100644 index 0000000..c9f8663 --- /dev/null +++ b/src/idmodels/constants.py @@ -0,0 +1,6 @@ +NHSN_INCIDENCE_SHIFT: float = 0.75 ** 4 # ≈ 0.316 + +POWER_TRANSFORM_OFFSET: float = 0.01 + +IN_SEASON_WEEK_MIN: int = 10 +IN_SEASON_WEEK_MAX: int = 45 diff --git a/src/idmodels/features.py b/src/idmodels/features.py new file mode 100644 index 0000000..4dfbb3e --- /dev/null +++ b/src/idmodels/features.py @@ -0,0 +1,393 @@ +import fnmatch +from abc import ABC, abstractmethod + +import numpy as np +import pandas as pd +from iddata.utils import get_holidays +from timeseriesutils import featurize + +from idmodels.spatial_utils import get_directional_neighbors, get_location_centroids, validate_wave_directions + + +class Feature(ABC): + """A single feature-engineering step applied to a DataFrame.""" + + + @abstractmethod + def apply( + self, + df: pd.DataFrame, + feat_names: list[str], + ) -> tuple[pd.DataFrame, list[str]]: + """ + Augment df with new feature columns. + + Parameters + ---------- + df : pd.DataFrame + Input data sorted by ["source", "location", "wk_end_date"]. + feat_names : list[str] + Running list of active feature column names. + + Returns + ------- + tuple of (augmented df, updated feat_names) + """ + ... 
+ + +class OneHotEncodingFeature(Feature): + """One-hot encode categorical columns.""" + + + def __init__(self, columns: list[str]): + self.columns = columns + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + for c in self.columns: + ohe = pd.get_dummies(df[c], prefix=c) + df = pd.concat([df, ohe], axis=1) + feat_names = feat_names + list(ohe.columns) + return df, feat_names + + +class HolidayFeature(Feature): + """ + Adds delta_xmas (signed distance from Christmas week) and xmas_spike. + Only delta_xmas is added to feat_names; xmas_spike is a covariate. + """ + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + df = df.merge( + get_holidays() + .query("holiday == 'Christmas Day'") + .drop(columns=["holiday", "date"]) + .rename(columns={"season_week": "xmas_week"}), + how="left", + on="season", + ).assign(delta_xmas=lambda x: x["season_week"] - x["xmas_week"]) + df["xmas_spike"] = np.maximum(3 - np.abs(df["delta_xmas"]), 0) + feat_names = feat_names + ["delta_xmas"] + return df, feat_names + + +class TaylorFeature(Feature): + """Windowed Taylor polynomial coefficients via timeseriesutils.""" + + + def __init__( + self, + column: str, + degree: int, + window_sizes: list[int], + window_align: str = "trailing", + fill_edges: bool = False, + ): + self.column = column + self.degree = degree + self.window_sizes = window_sizes + self.window_align = window_align + self.fill_edges = fill_edges + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + df, new_feat_names = featurize.featurize_data( + df, + group_columns=["source", "location"], + features=[ + { + "fun": "windowed_taylor_coefs", + "args": { + "columns": self.column, + "taylor_degree": self.degree, + "window_align": self.window_align, + "window_size": self.window_sizes, + "fill_edges": self.fill_edges, + }, + } + ], + ) + feat_names = feat_names + new_feat_names + return df, feat_names + + +class RollingMeanFeature(Feature): + """Rolling mean over specified window sizes.""" + + + def __init__( + self, + column: str, + window_sizes: list[int], + group_columns: list[str] | None = None, + ): + self.column = column + self.window_sizes = window_sizes + self.group_columns = group_columns if group_columns is not None else ["location"] + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + df, new_feat_names = featurize.featurize_data( + df, + group_columns=["source", "location"], + features=[ + { + "fun": "rollmean", + "args": { + "columns": self.column, + "group_columns": self.group_columns, + "window_size": self.window_sizes, + }, + } + ], + ) + feat_names = feat_names + new_feat_names + return df, feat_names + + +class LagFeature(Feature): + """ + Create lagged versions of specified columns. + + When columns=None, FeaturePipeline resolves this to all columns accumulated + since the previous LagFeature step. 
+ """ + + + def __init__(self, columns: list[str] | None, lags: list[int]): + self.columns = columns + self.lags = lags + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + if self.columns is None: + raise ValueError("LagFeature.columns must be resolved by FeaturePipeline before calling apply().") + df, new_feat_names = featurize.featurize_data( + df, + group_columns=["source", "location"], + features=[ + { + "fun": "lag", + "args": { + "columns": self.columns, + "lags": self.lags, + }, + } + ], + ) + feat_names = feat_names + new_feat_names + return df, feat_names + + +class HorizonTargetFeature(Feature): + """ + Expand df to max_horizon rows per original row; add inc_trans_cs_target, + horizon, and delta_target columns. Only horizon is added to feat_names. + """ + + + def __init__(self, column: str, max_horizon: int): + self.column = column + self.max_horizon = max_horizon + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + df, new_feat_names = featurize.featurize_data( + df, + group_columns=["source", "location"], + features=[ + { + "fun": "horizon_targets", + "args": { + "columns": self.column, + "horizons": list(range(1, self.max_horizon + 1)), + }, + } + ], + ) + df["delta_target"] = df[self.column + "_target"] - df[self.column] + feat_names = feat_names + [f for f in new_feat_names if f == "horizon"] + return df, feat_names + + +class LevelFeatureFilter(Feature): + """ + Remove absolute-level features from feat_names (does not drop df columns). + Used when incl_level_feats=False. + """ + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + level_feats = ( + ["inc_trans_cs", "inc_trans_cs_lag1", "inc_trans_cs_lag2"] + + fnmatch.filter(feat_names, "*taylor_d?_c0*") + + fnmatch.filter(feat_names, "*inc_trans_cs_rollmean*") + ) + feat_names = [f for f in feat_names if f not in level_feats] + return df, feat_names + + +class DirectionalWaveFeature(Feature): + """Spatial wave propagation features: inverse-distance-weighted neighbor averages.""" + + + def __init__( + self, + directions: list[str], + temporal_lags: list[int], + max_distance_km: float, + include_velocity: bool = False, + include_aggregate: bool = True, + ): + self.directions = directions + self.temporal_lags = temporal_lags + self.max_distance_km = max_distance_km + self.include_velocity = include_velocity + self.include_aggregate = include_aggregate + + + def apply(self, df: pd.DataFrame, feat_names: list[str]) -> tuple[pd.DataFrame, list[str]]: + df, wave_feat_names = self._compute(df) + feat_names = feat_names + wave_feat_names + return df, feat_names + + + def _compute(self, df: pd.DataFrame) -> tuple[pd.DataFrame, list[str]]: + """Compute directional wave features and return augmented df plus feature names.""" + validate_wave_directions(self.directions) + + agg_levels = df["agg_level"].unique() + if len(agg_levels) > 1: + raise ValueError(f"Multiple aggregation levels found: {agg_levels}. " + "Directional wave features currently support only one agg_level at a time.") + + agg_level = agg_levels[0] + + location_coords = get_location_centroids(agg_level=agg_level) + + locations_in_df = set(df["location"].unique()) + locations_with_coords = set(location_coords.keys()) + missing = locations_in_df - locations_with_coords + if missing: + raise ValueError(f"Missing coordinates for locations: {missing}. 
Cannot compute directional wave features.") + + # Precompute directional neighbors + neighbor_cache: dict = {} + for loc in locations_in_df: + neighbor_cache[loc] = {} + for direction in self.directions: + neighbor_cache[loc][direction] = get_directional_neighbors( + origin_loc=loc, + origin_coord=location_coords[loc], + all_coords=location_coords, + direction=direction, + max_distance_km=self.max_distance_km, + ) + + # Precompute all-direction neighbors for aggregate feature + all_neighbor_cache: dict = {} + if self.include_aggregate: + from idmodels.spatial_utils import haversine_distance + for loc in locations_in_df: + neighbors = [(other_loc, haversine_distance(location_coords[loc], coord)) + for other_loc, coord in location_coords.items() + if other_loc != loc] + neighbors = [(loc, dist) for loc, dist in neighbors if dist <= self.max_distance_km] + neighbors.sort(key=lambda x: x[1]) + all_neighbor_cache[loc] = neighbors + + df_sorted = df.sort_values(["location", "wk_end_date"]).reset_index(drop=True) + wave_features: dict = {} + + + def _weighted_avg(neighbors, date): + ws, wt = 0.0, 0.0 + for nloc, dist in neighbors: + val = df_sorted.loc[ + (df_sorted["location"] == nloc) & (df_sorted["wk_end_date"] == date), + "inc_trans_cs", + ] + if len(val) > 0 and not pd.isna(val.iloc[0]): + w = 1.0 / dist if dist > 0 else 1.0 + ws += w * val.iloc[0] + wt += w + return ws / wt if wt > 0 else np.nan + + + # Base directional features + for direction in self.directions: + feat_name = f"inc_trans_cs_wave_{direction}" + wave_features[feat_name] = [_weighted_avg(neighbor_cache[row["location"]][direction], row["wk_end_date"]) + for _, row in df_sorted.iterrows()] + + # Aggregate feature + if self.include_aggregate: + wave_features["inc_trans_cs_wave_avg"] = [ + _weighted_avg(all_neighbor_cache[row["location"]], row["wk_end_date"]) + for _, row in df_sorted.iterrows()] + + for feat_name, vals in wave_features.items(): + df_sorted[feat_name] = vals + + base_feat_names = list(wave_features.keys()) + + # Temporal lags + for feat_name in base_feat_names: + for lag in self.temporal_lags: + lagged = f"{feat_name}_lag{lag}" + df_sorted[lagged] = df_sorted.groupby("location")[feat_name].shift(lag) + + # Velocity features + if self.include_velocity: + for feat_name in base_feat_names: + lag1 = f"{feat_name}_lag1" + if lag1 not in df_sorted.columns: + df_sorted[lag1] = df_sorted.groupby("location")[feat_name].shift(1) + df_sorted[f"{feat_name}_velocity"] = df_sorted[feat_name] - df_sorted[lag1] + + df_sorted = df_sorted.sort_index() + + wave_feat_names = list(base_feat_names) + wave_feat_names += [f"{fn}_lag{lag}" for fn in base_feat_names for lag in self.temporal_lags] + if self.include_velocity: + wave_feat_names += [f"{fn}_velocity" for fn in base_feat_names] + + return df_sorted, wave_feat_names + + +class FeaturePipeline: + """ + Applies a sequence of Feature steps to a DataFrame. + + LagFeature(columns=None) resolves to all columns accumulated since the last + LagFeature step. Accumulator resets after each LagFeature step. + initial_feat_names columns are NOT included in the accumulator. 
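+
+    For example, given features=[TaylorFeature("inc_trans_cs", 2, [4, 6]),
+    RollingMeanFeature("inc_trans_cs", [2, 4]), LagFeature(columns=None, lags=[1, 2])],
+    the final step lags every Taylor and rolling-mean column produced by the two
+    preceding steps, after which the accumulator is cleared.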
+ """ + + + def __init__(self, features: list[Feature], initial_feat_names: list[str] | None = None): + self.features = features + self.initial_feat_names = initial_feat_names or [] + + + def apply(self, df: pd.DataFrame) -> tuple[pd.DataFrame, list[str]]: + feat_names = list(self.initial_feat_names) + accumulated_new: list[str] = [] + + for feature in self.features: + feat_names_before = list(feat_names) + + if isinstance(feature, LagFeature) and feature.columns is None: + feature = LagFeature(columns=list(accumulated_new), lags=feature.lags) + + df, feat_names = feature.apply(df, feat_names) + new_this_step = [f for f in feat_names if f not in feat_names_before] + + if isinstance(feature, LagFeature): + accumulated_new = [] + else: + accumulated_new.extend(new_this_step) + + return df, feat_names diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index 374addb..590ebc4 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -3,337 +3,189 @@ import lightgbm as lgb import numpy as np import pandas as pd -from iddata.loader import DiseaseDataLoader +from iddata.enums import Disease +from iddata.sources.flusurvnet import FluSurvNetDataSource +from iddata.sources.ilinet import ILINetDataSource +from iddata.sources.nhsn import NHSNDataSource +from iddata.sources.nssp import NSSPDataSource from tqdm.autonotebook import tqdm -from idmodels.config import DataSource, Disease, PowerTransform -from idmodels.preprocess import create_directional_wave_features, create_features_and_targets +from idmodels.config import GBQRModelConfig, RunConfig, SourceType +from idmodels.features import ( + DirectionalWaveFeature, + FeaturePipeline, + HolidayFeature, + HorizonTargetFeature, + LagFeature, + LevelFeatureFilter, + OneHotEncodingFeature, + RollingMeanFeature, + TaylorFeature, +) +from idmodels.model import IDModel from idmodels.utils import build_save_path -class GBQRModel(): - def __init__(self, model_config): - self.model_config = model_config - - - def run(self, run_config): - """ - Load flu data, generate predictions from a gbqr model, and save them as a csv file. - - Parameters - ---------- - run_config: configuration object with settings for the run - """ - # load flu data - if self.model_config.reporting_adj: - ilinet_kwargs = None - flusurvnet_kwargs = None +class GBQRModel(IDModel): + """Gradient Boosted Quantile Regression forecast model.""" + + + def __init__(self, model_config: GBQRModelConfig): + # Narrow self.model_config from ModelConfig to GBQRModelConfig so that + # type checkers resolve GBQRModelConfig-specific attributes in this class. 
+ super().__init__(model_config) + self.model_config: GBQRModelConfig = model_config + + + def _build_sources(self, run_config: RunConfig): + source_map = {SourceType.NHSN: NHSNDataSource(disease=run_config.disease), + SourceType.NSSP: NSSPDataSource(disease=run_config.disease), + SourceType.ILINET: ILINetDataSource(scale_to_positive=self.model_config.reporting_adj), + SourceType.FLUSURVNET: FluSurvNetDataSource(burden_adj=self.model_config.reporting_adj)} + if SourceType.NHSN in self.model_config.sources and SourceType.NSSP in self.model_config.sources: + raise ValueError("Only one of NHSN or NSSP may be selected.") + + return [source_map[s] for s in self.model_config.sources] + + + def _build_feature_pipeline(self, run_config: RunConfig) -> FeaturePipeline: + if run_config.disease in (Disease.FLU, Disease.RSV): + initial_feats = ["inc_trans_cs", "season_week", "log_pop"] else: - ilinet_kwargs = {"scale_to_positive": False} - flusurvnet_kwargs = {"burden_adj": False} - - valid_sources = {DataSource.FLUSURVNET, DataSource.NHSN, DataSource.ILINET, DataSource.NSSP} - if not set(self.model_config.sources) <= valid_sources: - raise ValueError("For GBQR, the only supported data sources are 'nhsn', 'flusurvnet', 'ilinet', or 'nssp'.") - - # Check if both nhsn and nssp data are included as sources - if (DataSource.NHSN in self.model_config.sources) and (DataSource.NSSP in self.model_config.sources): - raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.") - - fdl = DiseaseDataLoader() - if DataSource.NHSN in self.model_config.sources: - df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, - ilinet_kwargs=ilinet_kwargs, - flusurvnet_kwargs=flusurvnet_kwargs, - sources=self.model_config.sources, - power_transform=self.model_config.power_transform) - elif DataSource.NSSP in self.model_config.sources: - df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, - ilinet_kwargs=ilinet_kwargs, - flusurvnet_kwargs=flusurvnet_kwargs, - sources=self.model_config.sources, - power_transform=self.model_config.power_transform) - - if (run_config.states == []) & (run_config.hsas == []): - raise ValueError("User must request a non-empty set of locations to forecast for.") - - df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")] - df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")] - df = pd.concat([df_states, df_hsas], join = "inner", axis = 0) - df["unique_id"] = df["agg_level"] + df["location"] - - # augment data with features and target values - if (run_config.disease == Disease.FLU) or (run_config.disease == Disease.RSV): - init_feats = ["inc_trans_cs", "season_week", "log_pop"] - elif run_config.disease == Disease.COVID: - init_feats = ["inc_trans_cs", "log_pop"] - - # Create directional wave features if enabled + initial_feats = ["inc_trans_cs", "log_pop"] + + features = [] + if self.model_config.use_directional_waves: - wave_config = { - "enabled": True, - "directions": self.model_config.wave_directions, - "temporal_lags": self.model_config.wave_temporal_lags, - "max_distance_km": self.model_config.wave_max_distance_km, - "include_velocity": self.model_config.wave_include_velocity, - "include_aggregate": self.model_config.wave_include_aggregate, - } - df, wave_feat_names = create_directional_wave_features(df, wave_config) - init_feats = init_feats + wave_feat_names - - df, feat_names = create_features_and_targets( - df = df, - 
incl_level_feats=self.model_config.incl_level_feats, - max_horizon=run_config.max_horizon, - curr_feat_names=init_feats) - - # keep only rows that are in-season - if (run_config.disease == Disease.FLU) or (run_config.disease == Disease.RSV): + features.append( + DirectionalWaveFeature( + directions=self.model_config.wave_directions, + temporal_lags=self.model_config.wave_temporal_lags, + max_distance_km=self.model_config.wave_max_distance_km, + include_velocity=self.model_config.wave_include_velocity, + include_aggregate=self.model_config.wave_include_aggregate, + ) + ) + + features += [ + OneHotEncodingFeature(columns=["source", "agg_level", "location"]), + HolidayFeature(), + LagFeature(columns=["inc_trans_cs"], lags=[1, 2]), + TaylorFeature(column="inc_trans_cs", degree=2, window_sizes=[4, 6]), + TaylorFeature(column="inc_trans_cs", degree=1, window_sizes=[3, 5]), + RollingMeanFeature(column="inc_trans_cs", window_sizes=[2, 4]), + LagFeature(columns=None, lags=[1, 2]), + HorizonTargetFeature(column="inc_trans_cs", max_horizon=run_config.max_horizon), + ] + + if not self.model_config.incl_level_feats: + features.append(LevelFeatureFilter()) + + return FeaturePipeline(features=features, initial_feat_names=initial_feats) + + + def _fit_and_predict(self, df: pd.DataFrame, feat_names: list[str], run_config: RunConfig) -> pd.DataFrame: + """Fit bagged LightGBM and return long-format predictions in inc_trans_cs space.""" + if run_config.disease in (Disease.FLU, Disease.RSV): df = df.query("season_week >= 5 and season_week <= 45") - - # "test set" df used to generate look-ahead predictions - df_test = df.loc[df.wk_end_date == df.wk_end_date.max()] \ - .copy() - - # "train set" df for model fitting; target value non-missing + + df_test = df.loc[df.wk_end_date == df.wk_end_date.max()].copy() df_train = df.loc[~df["delta_target"].isna().values] - - # train model and obtain test set predictinos + if self.model_config.fit_locations_separately: unique_ids = df_test["unique_id"].unique() - preds_df = [ - self._train_gbq_and_predict( - run_config, - df_train, df_test, feat_names, location - ) for location in unique_ids - ] - preds_df = pd.concat(preds_df, axis=0) - else: - preds_df = self._train_gbq_and_predict( - run_config, - df_train, df_test, feat_names + preds_df = pd.concat( + [self._train_gbq_and_predict(run_config, df_train, df_test, feat_names, loc) + for loc in unique_ids], + axis=0, ) - - # save - save_path = build_save_path( - root=run_config.output_root, - run_config=run_config, - model_config=self.model_config - ) - preds_df.to_csv(save_path, index=False) - - - def _train_gbq_and_predict(self, run_config, - df_train, df_test, feat_names, location = None): - """ - Train gbq model and get predictions on the original target scale, - formatted in the FluSight hub format. - - Parameters - ---------- - run_config: configuration object with settings for the run - df_train: data frame with training data - df_test: data frame with test data - feat_names: list of names of columns with features - location: optional string of location to fit to. 
Default, None, fits to all locations - - Returns - ------- - Pandas data frame with test set predictions in FluSight hub format - """ - # filter to location if necessary + else: + preds_df = self._train_gbq_and_predict(run_config, df_train, df_test, feat_names) + + return preds_df + + + def _train_gbq_and_predict(self, run_config, df_train, df_test, feat_names, location=None): if location is not None: - df_test = df_test.query(f'location == "{location}"') - df_train = df_train.query(f'location == "{location}"') - - # get x and y + df_test = df_test.query(f'unique_id == "{location}"') + df_train = df_train.query(f'unique_id == "{location}"') + x_test = df_test[feat_names] x_train = df_train[feat_names] y_train = df_train["delta_target"] - - # test set predictions: - # same number of rows as df_test, one column per quantile level - test_pred_qs_df = self._get_test_quantile_predictions( - run_config, - df_train, x_train, y_train, x_test - ) - - # add predictions to original test df + + test_pred_qs_df = self._get_test_quantile_predictions(run_config, df_train, x_train, y_train, x_test) + df_test.reset_index(drop=True, inplace=True) df_test_w_preds = pd.concat([df_test, test_pred_qs_df], axis=1) - - # melt to get columns into rows, keeping only the things we need to invert data - # transforms later on - cols_to_keep = ["source", "agg_level", "location", "wk_end_date", "pop", - "inc_trans_cs", "horizon", + + cols_to_keep = ["source", "agg_level", "location", "wk_end_date", "pop", "inc_trans_cs", "horizon", "inc_trans_center_factor", "inc_trans_scale_factor"] preds_df = df_test_w_preds[cols_to_keep + run_config.q_labels] preds_df = preds_df.loc[preds_df["source"].isin(["nhsn", "nssp"])] preds_df = pd.melt(preds_df, - id_vars=cols_to_keep, - var_name="quantile", - value_name = "delta_hat") - - # build data frame with predictions on the original scale - preds_df["inc_trans_cs_target_hat"] = preds_df["inc_trans_cs"] + preds_df["delta_hat"] - preds_df["inc_trans_target_hat"] = (preds_df["inc_trans_cs_target_hat"] + preds_df["inc_trans_center_factor"]) * (preds_df["inc_trans_scale_factor"] + 0.01) - if self.model_config.power_transform == PowerTransform.FOURTH_ROOT: - inv_power = 4 - elif self.model_config.power_transform == PowerTransform.NONE: - inv_power = 1 - else: - raise ValueError(f"unsupported power_transform: {self.model_config.power_transform!r}") - - preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4) - - # get predictions into the format needed for FluSight hub submission - if "nhsn" in preds_df["source"].unique(): - # turn nhsn rates back into counts - preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000 - target_name = "wk inc " + run_config.disease + " hosp" - elif "nssp" in preds_df["source"].unique(): - preds_df["value"] = preds_df["value"] / 100 # percentage to proportion - preds_df["value"] = np.minimum(preds_df["value"], 1.0) - target_name = "wk inc " + run_config.disease + " prop ed visits" - - keep_agg_levels = False - gcols = ["location", "reference_date", "horizon", "target_end_date", "target", "output_type"] - # we count national as state since it is coded using the same 2-digit fips code - preds_df["geo_level"] = np.where(preds_df["agg_level"] == "national", "state", preds_df["agg_level"]) - if len(preds_df["geo_level"].unique()) > 1: - keep_agg_levels = True - gcols.insert(0, "agg_level") - - preds_df["value"] = np.maximum(preds_df["value"], 0.0) - preds_df = self._format_as_flusight_output(preds_df, 
run_config.ref_date, target_name, keep_agg_levels) - - # sort quantiles to avoid quantile crossing - preds_df = self._quantile_noncrossing(preds_df, gcols = gcols) - + id_vars=cols_to_keep, + var_name="output_type_id", + value_name="delta_hat") + + # value in inc_trans_cs space (before inverse transform) + preds_df["value"] = preds_df["inc_trans_cs"] + preds_df["delta_hat"] + preds_df = preds_df.drop(columns=["delta_hat"]) + + # Sort quantiles to prevent crossing (in transformed space, which is monotone) + gcols = ["source", "agg_level", "location", "wk_end_date", "horizon"] + preds_df = self._quantile_noncrossing(preds_df, gcols=gcols) + return preds_df - def _get_test_quantile_predictions(self, run_config, - df_train, x_train, y_train, x_test): - """ - Train the model on bagged subsets of the training data and obtain - quantile predictions. This is the heart of the method. - - Parameters - ---------- - run_config: configuration object with settings for the run - df_train: Pandas data frame with training data - x_train: numpy array with training instances in rows, features in columns - y_train: numpy array with target values - x_test: numpy array with test instances in rows, features in columns - - Returns - ------- - Pandas data frame with test set predictions. The number of rows matches - the number of rows of `x_test`. The number of columns matches the number - of quantile levels for predictions as specified in the `run_config`. - Column names are given by `run_config.q_labels`. - """ - # seed for random number generation, based on reference date + def _get_test_quantile_predictions(self, run_config, df_train, x_train, y_train, x_test): rng_seed = int(calendar.timegm(run_config.ref_date.timetuple())) rng = np.random.default_rng(seed=rng_seed) - # seeds for lgb model fits, one per combination of bag and quantile level lgb_seeds = rng.integers(1e8, size=(self.model_config.num_bags, len(run_config.q_levels))) - - # training loop over bags + test_preds_by_bag = np.empty((x_test.shape[0], self.model_config.num_bags, len(run_config.q_levels))) - train_seasons = df_train["season"].unique() - - feat_importance = list() - + feat_importance = [] + for b in tqdm(range(self.model_config.num_bags), "Bag number"): - # get indices of observations that are in bag - bag_seasons = rng.choice( - train_seasons, - size = int(len(train_seasons) * self.model_config.bag_frac_samples), - replace=False) + bag_seasons = rng.choice(train_seasons, + size=int(len(train_seasons) * self.model_config.bag_frac_samples), + replace=False) bag_obs_inds = df_train["season"].isin(bag_seasons) - + for q_ind, q_level in enumerate(run_config.q_levels): - # fit to bag - model = lgb.LGBMRegressor( - verbosity=-1, - objective="quantile", - alpha=q_level, - random_state=lgb_seeds[b, q_ind]) + model = lgb.LGBMRegressor(verbosity=-1, + objective="quantile", + alpha=q_level, + random_state=lgb_seeds[b, q_ind]) model.fit(X=x_train.loc[bag_obs_inds, :], y=y_train.loc[bag_obs_inds]) - feat_importance.append( - pd.DataFrame({ - "feat": x_train.columns, - "importance": model.feature_importances_, - "b": b, - "q_level": q_level - }) - ) - - # test set predictions + feat_importance.append(pd.DataFrame({"feat": x_train.columns, + "importance": model.feature_importances_, + "b": b, + "q_level": q_level})) test_preds_by_bag[:, b, q_ind] = model.predict(X=x_test) - - # combine and save feature importance scores + if self.model_config.save_feat_importance: - feat_importance = pd.concat(feat_importance, axis=0) - save_path = build_save_path( - 
root=run_config.artifact_store_root, - run_config=run_config, - model_config=self.model_config, - subdir="feat_importance") - feat_importance.to_csv(save_path, index=False) - - # combined predictions across bags: median + feat_importance_df = pd.concat(feat_importance, axis=0) + save_path = build_save_path(root=run_config.artifact_store_root, + run_config=run_config, + model_config=self.model_config, + subdir="feat_importance") + feat_importance_df.to_csv(save_path, index=False) + test_pred_qs = np.median(test_preds_by_bag, axis=1) - - # test predictions as a data frame, one column per quantile level test_pred_qs_df = pd.DataFrame(test_pred_qs) test_pred_qs_df.columns = run_config.q_labels - return test_pred_qs_df - def _format_as_flusight_output(self, preds_df, ref_date, target_name, keep_agg_levels = False): - # keep just required columns and rename to match hub format - req_cols = ["location", "wk_end_date", "horizon", "quantile", "value"] - - if keep_agg_levels: - req_cols.insert(0, "agg_level") - - preds_df = preds_df[req_cols] \ - .rename(columns={"quantile": "output_type_id"}) - - preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days") - preds_df["reference_date"] = ref_date - preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - ref_date).dt.days / 7).astype(int) - preds_df["target"] = target_name - - preds_df["output_type"] = "quantile" - preds_df.drop(columns="wk_end_date", inplace=True) - - return preds_df - - - def _quantile_noncrossing(self, preds_df, gcols): - """ - Sort predictions to be in alignment with quantile levels, to prevent - quantile crossing. - - Parameters - ---------- - preds_df: data frame with quantile predictions - gcols: columns to group by; predictions will be sorted within those groups - - Returns - ------- - Sorted version of preds_df, guaranteed not to have quantile crossing - """ - g = preds_df.set_index(gcols).groupby(gcols) - preds_df = g[["output_type_id", "value"]] \ - .transform(lambda x: x.sort_values()) \ - .reset_index() - + def _quantile_noncrossing(self, preds_df: pd.DataFrame, gcols: list[str]) -> pd.DataFrame: + # Sort rows so quantile labels are ascending within each group, then sort values the same way. Positional + # alignment after reset_index pairs the smallest value with the smallest quantile label, fixing any crossings. + # All non-gcol columns (inc_trans_center_factor, inc_trans_scale_factor, pop, …) are preserved — only "value" is + # reassigned. 
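+        # For example, within one group, labels ["0.25", "0.5", "0.75"] carrying values
+        # [5.0, 3.0, 4.0] come back paired as [3.0, 4.0, 5.0], so quantiles never cross.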
+ preds_df = preds_df.sort_values(gcols + ["output_type_id"]).reset_index(drop=True) + preds_df["value"] = preds_df.groupby(gcols)["value"].transform(np.sort) return preds_df diff --git a/src/idmodels/model.py b/src/idmodels/model.py new file mode 100644 index 0000000..c7fa16d --- /dev/null +++ b/src/idmodels/model.py @@ -0,0 +1,146 @@ +from abc import ABC, abstractmethod + +import numpy as np +import pandas as pd +from iddata.ancillary.population import PopulationData +from iddata.loader import DiseaseDataLoader +from iddata.sources.base import DataSource as IdDataSource + +# Import SourceType from iddata via the re-export in config +from idmodels.config import ModelConfig, PowerTransform, RunConfig, SourceType +from idmodels.constants import NHSN_INCIDENCE_SHIFT +from idmodels.features import FeaturePipeline +from idmodels.transforms import ( + CenterScaleTransform, + ComposedTransform, + FourthRootTransform, + IdentityTransform, + Transform, +) +from idmodels.utils import build_save_path + + +class IDModel(ABC): + """ + Abstract base class for infectious disease forecast models. Subclasses implement _build_sources(), + _build_feature_pipeline(), and _fit_and_predict(). run() orchestrates the full workflow. + """ + + + def __init__(self, model_config: ModelConfig): + self.model_config = model_config + + + def run(self, run_config: RunConfig) -> None: + """Load data, generate predictions, and save to file.""" + sources = self._build_sources(run_config) + df = DiseaseDataLoader().load(sources=sources, as_of=run_config.ref_date, ancillary=[PopulationData()]) + df = self._filter_locations(df, run_config) + df["unique_id"] = df["agg_level"] + df["location"] + + transform = self._build_transform() + df = transform.apply(df) + + pipeline = self._build_feature_pipeline(run_config) + df, feat_names = pipeline.apply(df) + + preds_df = self._fit_and_predict(df, feat_names, run_config) + + preds_df = self._invert_and_scale(preds_df, transform, run_config) + + preds_df = self._format_output(preds_df, run_config) + save_path = build_save_path(root=run_config.output_root, + run_config=run_config, + model_config=self.model_config) + preds_df["output_type_id"] = preds_df["output_type_id"].astype(str) + preds_df.to_csv(save_path, index=False) + + + @abstractmethod + def _build_sources(self, run_config: RunConfig) -> list[IdDataSource]: + """Instantiate iddata DataSource objects for this model.""" + ... + + + @abstractmethod + def _build_feature_pipeline(self, run_config: RunConfig) -> FeaturePipeline: + """Return the FeaturePipeline for this model.""" + ... + + + @abstractmethod + def _fit_and_predict(self, df: pd.DataFrame, feat_names: list[str], run_config: RunConfig) -> pd.DataFrame: + """ + Fit model and generate quantile predictions. + + Returns long-format DataFrame with columns: + source, agg_level, location, wk_end_date, pop, horizon, + inc_trans_cs, inc_trans_center_factor, inc_trans_scale_factor, + output_type_id, value (in inc_trans_cs space) + """ + ... 
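Concretely, the three hooks are all a subclass supplies; run() owns the rest of the workflow. A hypothetical toy subclass (all names invented for illustration, not part of this diff) that persists the last observed value at every quantile:

    import pandas as pd
    from iddata.sources.nhsn import NHSNDataSource

    from idmodels.features import FeaturePipeline
    from idmodels.model import IDModel


    class PersistenceModel(IDModel):
        """Illustrative only: flat forecast at the last observed inc_trans_cs."""

        def _build_sources(self, run_config):
            return [NHSNDataSource(disease=run_config.disease)]

        def _build_feature_pipeline(self, run_config):
            return FeaturePipeline(features=[], initial_feat_names=["inc_trans_cs"])

        def _fit_and_predict(self, df, feat_names, run_config):
            # one row per location, horizon, and quantile label, all carrying the last value
            last = df.groupby("unique_id").tail(1)
            frames = [last.assign(horizon=h, output_type_id=q, value=last["inc_trans_cs"])
                      for h in range(1, run_config.max_horizon + 1)
                      for q in run_config.q_labels]
            return pd.concat(frames, ignore_index=True)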
+ + + def _build_transform(self) -> Transform: + """Default: ComposedTransform([power_transform, CenterScaleTransform()]).""" + additive_shift = NHSN_INCIDENCE_SHIFT if SourceType.NHSN in self.model_config.sources else 0.0 + if self.model_config.power_transform == PowerTransform.FOURTH_ROOT: + power_t: Transform = FourthRootTransform(additive_shift=additive_shift) + else: + power_t = IdentityTransform(additive_shift=additive_shift) + return ComposedTransform([power_t, CenterScaleTransform()]) + + + def _filter_locations(self, df: pd.DataFrame, run_config: RunConfig) -> pd.DataFrame: + if not run_config.states and not run_config.hsas: + raise ValueError("RunConfig must specify at least one state or HSA.") + + df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")] + df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")] + return pd.concat([df_states, df_hsas], join="inner", axis=0) + + + def _invert_and_scale(self, preds_df: pd.DataFrame, transform: Transform, run_config: RunConfig) -> pd.DataFrame: + """Inverse transform predictions then convert to original units.""" + preds_df["value"] = transform.invert(preds_df["value"].values, context=preds_df) + preds_df["value"] = np.maximum(preds_df["value"], 0.0) + + if SourceType.NHSN in self.model_config.sources: + preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000 + elif SourceType.NSSP in self.model_config.sources: + preds_df["value"] = np.minimum(preds_df["value"] / 100, 1.0) + + return preds_df + + + def _format_output(self, preds_df: pd.DataFrame, run_config: RunConfig) -> pd.DataFrame: + """Reshape to FluSight hub submission format.""" + if SourceType.NHSN in self.model_config.sources: + target_name = f"wk inc {run_config.disease.value} hosp" + else: + target_name = f"wk inc {run_config.disease.value} prop ed visits" + + preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta( + 7 * preds_df["horizon"], unit="days" + ) + preds_df["reference_date"] = run_config.ref_date + preds_df["horizon"] = ( + pd.to_timedelta(preds_df["target_end_date"].dt.date - run_config.ref_date).dt.days / 7 + ).astype(int) + preds_df["output_type"] = "quantile" + preds_df["target"] = target_name + preds_df.drop(columns="wk_end_date", inplace=True) + + req_cols = ["location", "reference_date", "horizon", "target_end_date", + "target", "output_type", "output_type_id", "value"] + + preds_df["geo_level"] = np.where( + preds_df["agg_level"] == "national", "state", preds_df["agg_level"] + ) + if len(preds_df["geo_level"].unique()) > 1: + req_cols = ["agg_level"] + req_cols + + return preds_df.sort_values( + ["output_type_id", "horizon", "location", "agg_level"], + ascending=[True, True, True, False], + ).reset_index(drop=True)[req_cols] diff --git a/src/idmodels/preprocess.py b/src/idmodels/preprocess.py deleted file mode 100644 index 1581762..0000000 --- a/src/idmodels/preprocess.py +++ /dev/null @@ -1,369 +0,0 @@ -import fnmatch - -import numpy as np -import pandas as pd -from iddata.utils import get_holidays -from timeseriesutils import featurize - -from idmodels.spatial_utils import get_directional_neighbors, get_location_centroids, validate_wave_directions - - -def create_features_and_targets(df, incl_level_feats, max_horizon, curr_feat_names = []): - ''' - Create features and targets for prediction - - Parameters - ---------- - df: pandas dataframe - data frame with data to "featurize" - incl_level_feats: boolean - include features that are a measure of local level of the 
signal? - max_horizon: int - maximum forecast horizon - curr_feat_names: list of strings - list of names of columns in `df` containing existing features - - Returns - ------- - tuple with: - - the input data frame, augmented with additional columns with feature and - target values - - a list of all feature names, columns in the data frame - ''' - - # current features; will be updated - feat_names = curr_feat_names - - # one-hot encodings of data source, agg_level, and location - for c in ["source", "agg_level", "location"]: - ohe = pd.get_dummies(df[c], prefix=c) - df = pd.concat([df, ohe], axis=1) - feat_names = feat_names + list(ohe.columns) - - # season week relative to christmas - df = df.merge( - get_holidays() \ - .query("holiday == 'Christmas Day'") \ - .drop(columns=["holiday", "date"]) \ - .rename(columns={"season_week": "xmas_week"}), - how="left", - on="season") \ - .assign(delta_xmas = lambda x: x["season_week"] - x["xmas_week"]) - - feat_names = feat_names + ["delta_xmas"] - - # features summarizing data within each combination of source and location - df, new_feat_names = featurize.featurize_data( - df, group_columns=["source", "location"], - features = [ - { - "fun": "windowed_taylor_coefs", - "args": { - "columns": "inc_trans_cs", - "taylor_degree": 2, - "window_align": "trailing", - "window_size": [4, 6], - "fill_edges": False - } - }, - { - "fun": "windowed_taylor_coefs", - "args": { - "columns": "inc_trans_cs", - "taylor_degree": 1, - "window_align": "trailing", - "window_size": [3, 5], - "fill_edges": False - } - }, - { - "fun": "rollmean", - "args": { - "columns": "inc_trans_cs", - "group_columns": ["location"], - "window_size": [2, 4] - } - } - ]) - feat_names = feat_names + new_feat_names - - df, new_feat_names = featurize.featurize_data( - df, group_columns=["source", "location"], - features = [ - { - "fun": "lag", - "args": { - "columns": ["inc_trans_cs"] + new_feat_names, - "lags": [1, 2] - } - } - ]) - feat_names = feat_names + new_feat_names - - # add forecast targets - df, new_feat_names = featurize.featurize_data( - df, group_columns=["source", "location"], - features = [ - { - "fun": "horizon_targets", - "args": { - "columns": "inc_trans_cs", - "horizons": [(i + 1) for i in range(max_horizon)] - } - } - ]) - feat_names = feat_names + new_feat_names - - # we will model the differences between the prediction target and the most - # recent observed value - df["delta_target"] = df["inc_trans_cs_target"] - df["inc_trans_cs"] - - # if requested, drop features that involve absolute level - if not incl_level_feats: - feat_names = _drop_level_feats(feat_names) - - return df, feat_names - - -def _drop_level_feats(feat_names): - level_feats = ["inc_trans_cs", "inc_trans_cs_lag1", "inc_trans_cs_lag2"] + \ - fnmatch.filter(feat_names, "*taylor_d?_c0*") + \ - fnmatch.filter(feat_names, "*inc_trans_cs_rollmean*") - feat_names = [f for f in feat_names if f not in level_feats] - return feat_names - - -def create_directional_wave_features(df, wave_config=None): - """ - Create spatial directional wave features. - - For each location and time point, computes distance-weighted averages - of neighboring locations' incidence in specified directions (e.g., N, S, E, W). - Also computes lagged versions and optionally velocity (rate of change) features. 
- - Parameters - ---------- - df : pandas.DataFrame - Data frame with columns: location, wk_end_date, inc_trans_cs, agg_level - wave_config : dict, optional - Configuration dictionary with keys: - - 'enabled': bool (default: False) - whether to generate features - - 'directions': list of str (default: ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']) - Subset of: N, NE, E, SE, S, SW, W, NW - - 'temporal_lags': list of int (default: [1, 2]) - temporal lags to include - lag1 means t-1, lag2 means t-2, etc. - - 'max_distance_km': float (default: 1000) - max distance for neighbors - - 'include_velocity': bool (default: False) - include rate-of-change features - - 'include_aggregate': bool (default: True) - include overall weighted average - - Returns - ------- - df : pandas.DataFrame - Input dataframe augmented with wave features - wave_feat_names : list of str - List of new feature names added to df - - Notes - ----- - - Lag semantics: lag1 uses time t-1, lag2 uses time t-2, etc. - - Velocity features compute: wave(t) - wave(t-1) - - Distance weighting uses inverse distance: weight = 1 / distance - - Features are computed per location and time point - """ - # Return early if not enabled - if wave_config is None or not wave_config.get("enabled", False): - return df, [] - - # Extract configuration with defaults - directions = wave_config.get("directions", ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]) - temporal_lags = wave_config.get("temporal_lags", [1, 2]) - max_distance_km = wave_config.get("max_distance_km", 1000) - include_velocity = wave_config.get("include_velocity", False) - include_aggregate = wave_config.get("include_aggregate", True) - - # Validate directions - validate_wave_directions(directions) - - # Get aggregation level(s) from dataframe - agg_levels = df["agg_level"].unique() - if len(agg_levels) > 1: - raise ValueError( - f"Multiple aggregation levels found: {agg_levels}. " - f"Directional wave features currently support only one agg_level at a time." - ) - agg_level = agg_levels[0] - - # Get location centroids for this aggregation level - try: - location_coords = get_location_centroids(agg_level=agg_level) - except ValueError as e: - raise ValueError( - f"Cannot create directional wave features: {str(e)}" - ) - - # Filter to locations present in both data and coordinate lookup - locations_in_df = set(df["location"].unique()) - locations_with_coords = set(location_coords.keys()) - locations_to_use = locations_in_df.intersection(locations_with_coords) - - if len(locations_to_use) < len(locations_in_df): - missing = locations_in_df - locations_with_coords - raise ValueError( - f"Missing coordinates for locations: {missing}. " - f"Cannot compute directional wave features." 
- ) - - # Precompute directional neighbors for each location - neighbor_cache = {} - for loc in locations_to_use: - neighbor_cache[loc] = {} - for direction in directions: - neighbors = get_directional_neighbors( - origin_loc=loc, - origin_coord=location_coords[loc], - all_coords=location_coords, - direction=direction, - max_distance_km=max_distance_km - ) - neighbor_cache[loc][direction] = neighbors - - # Also compute all neighbors (for aggregate feature) - if include_aggregate: - all_neighbor_cache = {} - for loc in locations_to_use: - # Get all neighbors regardless of direction - neighbors = [] - for other_loc, other_coord in location_coords.items(): - if other_loc == loc: - continue - from idmodels.spatial_utils import haversine_distance - distance = haversine_distance(location_coords[loc], other_coord) - if distance <= max_distance_km: - neighbors.append((other_loc, distance)) - neighbors.sort(key=lambda x: x[1]) - all_neighbor_cache[loc] = neighbors - - # Create features for each direction - wave_features = {} - - # Sort dataframe by location and date for efficient processing - df_sorted = df.sort_values(["location", "wk_end_date"]).reset_index(drop=True) - - # Compute base directional features (at time t) - for direction in directions: - feat_name = f"inc_trans_cs_wave_{direction}" - feat_values = [] - - for idx, row in df_sorted.iterrows(): - loc = row["location"] - date = row["wk_end_date"] - - # Get neighbors in this direction - neighbors = neighbor_cache[loc][direction] - - if len(neighbors) == 0: - # No neighbors in this direction - feat_values.append(np.nan) - continue - - # Compute distance-weighted average - weighted_sum = 0.0 - weight_sum = 0.0 - - for neighbor_loc, distance in neighbors: - # Get neighbor's inc_trans_cs at same time point - neighbor_value = df_sorted[ - (df_sorted["location"] == neighbor_loc) & - (df_sorted["wk_end_date"] == date) - ]["inc_trans_cs"] - - if len(neighbor_value) > 0 and not pd.isna(neighbor_value.iloc[0]): - # Inverse distance weighting - weight = 1.0 / distance if distance > 0 else 1.0 - weighted_sum += weight * neighbor_value.iloc[0] - weight_sum += weight - - if weight_sum > 0: - feat_values.append(weighted_sum / weight_sum) - else: - feat_values.append(np.nan) - - wave_features[feat_name] = feat_values - - # Compute aggregate feature (overall weighted average) - if include_aggregate: - feat_name = "inc_trans_cs_wave_avg" - feat_values = [] - - for idx, row in df_sorted.iterrows(): - loc = row["location"] - date = row["wk_end_date"] - - neighbors = all_neighbor_cache[loc] - - if len(neighbors) == 0: - feat_values.append(np.nan) - continue - - # Compute distance-weighted average - weighted_sum = 0.0 - weight_sum = 0.0 - - for neighbor_loc, distance in neighbors: - neighbor_value = df_sorted[ - (df_sorted["location"] == neighbor_loc) & - (df_sorted["wk_end_date"] == date) - ]["inc_trans_cs"] - - if len(neighbor_value) > 0 and not pd.isna(neighbor_value.iloc[0]): - weight = 1.0 / distance if distance > 0 else 1.0 - weighted_sum += weight * neighbor_value.iloc[0] - weight_sum += weight - - if weight_sum > 0: - feat_values.append(weighted_sum / weight_sum) - else: - feat_values.append(np.nan) - - wave_features[feat_name] = feat_values - - # Add base features to dataframe - for feat_name, feat_values in wave_features.items(): - df_sorted[feat_name] = feat_values - - # Create lagged features - lagged_features = {} - base_feat_names = list(wave_features.keys()) - - for feat_name in base_feat_names: - for lag in temporal_lags: - lagged_feat_name = 
f"{feat_name}_lag{lag}" - # Use groupby to create lags within each location - df_sorted[lagged_feat_name] = df_sorted.groupby("location")[feat_name].shift(lag) - lagged_features[lagged_feat_name] = None # Just track the name - - # Create velocity features (rate of change) - if include_velocity: - velocity_features = {} - for feat_name in base_feat_names: - # Velocity = current - lag1 - lag1_name = f"{feat_name}_lag1" - if lag1_name in df_sorted.columns or 1 in temporal_lags: - velocity_feat_name = f"{feat_name}_velocity" - if lag1_name not in df_sorted.columns: - # Need to create lag1 if it doesn't exist - df_sorted[lag1_name] = df_sorted.groupby("location")[feat_name].shift(1) - df_sorted[velocity_feat_name] = df_sorted[feat_name] - df_sorted[lag1_name] - velocity_features[velocity_feat_name] = None - - # Restore original index order - df_sorted = df_sorted.sort_index() - - # Collect all feature names - wave_feat_names = list(wave_features.keys()) - wave_feat_names += list(lagged_features.keys()) - if include_velocity: - wave_feat_names += list(velocity_features.keys()) - - return df_sorted, wave_feat_names - diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index 2dcda12..c7c291b 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -1,181 +1,126 @@ import numpy as np import pandas as pd -from iddata.loader import DiseaseDataLoader -from iddata.utils import get_holidays +from iddata.sources.nhsn import NHSNDataSource +from iddata.sources.nssp import NSSPDataSource from sarix import sarix -from idmodels.config import DataSource, PowerTransform, SARIXFourierModelConfig -from idmodels.utils import build_save_path +from idmodels.config import RunConfig, SARIXFourierModelConfig, SARIXModelConfig, SourceType +from idmodels.features import FeaturePipeline, HolidayFeature +from idmodels.model import IDModel -class SARIXModel(): - def __init__(self, model_config): - self.model_config = model_config +class SARIXModel(IDModel): + """SARIX (Bayesian ARIMA with exogenous covariates) forecast model.""" + + + def __init__(self, model_config: SARIXModelConfig): + # Narrow self.model_config from ModelConfig to SARIXModelConfig so that + # type checkers resolve SARIXModelConfig-specific attributes in this class. + super().__init__(model_config) + self.model_config: SARIXModelConfig = model_config + + + def _build_sources(self, run_config: RunConfig): + sources_map = {SourceType.NHSN: NHSNDataSource(disease=run_config.disease), + SourceType.NSSP: NSSPDataSource(disease=run_config.disease)} + if not set(self.model_config.sources) <= sources_map.keys(): + raise ValueError("SARIXModel only supports NHSN and NSSP sources.") + + if SourceType.NHSN in self.model_config.sources and SourceType.NSSP in self.model_config.sources: + raise ValueError("Only one of NHSN or NSSP may be selected.") + + return [sources_map[s] for s in self.model_config.sources] + + + def _build_feature_pipeline(self, run_config: RunConfig) -> FeaturePipeline: + return FeaturePipeline(features=[HolidayFeature()], + initial_feat_names=["inc_trans_cs"] + self.model_config.x) - def _get_extra_sarix_params(self, df): - """Return extra parameters to pass to SARIX constructor. 
Returns empty dict by default."""
-        return {}
-    def run(self, run_config):
-        valid_sources = {DataSource.NHSN, DataSource.NSSP}
-        if not set(self.model_config.sources) <= valid_sources:
-            raise ValueError("For SARIX, the only supported data sources are 'nhsn' or 'nssp'.")
-
-        # Check if both nhsn and nssp data are included as sources
-        if (DataSource.NHSN in self.model_config.sources) and (DataSource.NSSP in self.model_config.sources):
-            raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")
-
-        fdl = DiseaseDataLoader()
-        if DataSource.NHSN in self.model_config.sources:
-            df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
-                               sources=self.model_config.sources,
-                               power_transform=self.model_config.power_transform)
-            target_name = "wk inc " + run_config.disease + " hosp"
-        elif DataSource.NSSP in self.model_config.sources:
-            df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
-                               sources=self.model_config.sources,
-                               power_transform=self.model_config.power_transform)
-            target_name = "wk inc " + run_config.disease + " prop ed visits"
-
-        if (run_config.states == []) & (run_config.hsas == []):
-            raise ValueError("User must request a non-empty set of locations to forecast for.")
-
-        df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
-        df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
-        df = pd.concat([df_states, df_hsas], join = "inner", axis = 0)
-        df["unique_id"] = df["agg_level"] + df["location"]
-
-        # season week relative to christmas
-        df = df.merge(
-            get_holidays() \
-                .query("holiday == 'Christmas Day'") \
-                .drop(columns=["holiday", "date"]) \
-                .rename(columns={"season_week": "xmas_week"}),
-            how="left",
-            on="season") \
-            .assign(delta_xmas = lambda x: x["season_week"] - x["xmas_week"])
-        df["xmas_spike"] = np.maximum(3 - np.abs(df["delta_xmas"]), 0)
-
-        # missing values are interpolated when possible
+    def _fit_and_predict(self, df: pd.DataFrame, feat_names: list[str], run_config: RunConfig) -> pd.DataFrame:
+        """Fit SARIX and return quantile predictions in long format (inc_trans_cs space)."""
         xy_colnames = self.model_config.x + ["inc_trans_cs"]
+        # xmas_spike (added by HolidayFeature) enters the model only when listed in model_config.x
+        df = df.query("wk_end_date >= '2022-10-01'").interpolate()
+        batched_xy = df[xy_colnames].values.reshape(
+            len(df["unique_id"].unique()), -1, len(xy_colnames))
-        batched_xy = df[xy_colnames].values.reshape(len(df["unique_id"].unique()), -1, len(xy_colnames))
-        # Get any extra parameters for the SARIX constructor
         extra_params = self._get_extra_sarix_params(df)
-        sarix_fit_all_locs_theta_pooled = sarix.SARIX(
-            xy = batched_xy,
-            p = self.model_config.p,
-            d = self.model_config.d,
-            P = self.model_config.P,
-            D = self.model_config.D,
-            season_period = self.model_config.season_period,
-            transform="none",  # transformations are handled outside of SARIX
-            theta_pooling=self.model_config.theta_pooling,
-            sigma_pooling=self.model_config.sigma_pooling,
-            forecast_horizon=run_config.max_horizon,
-            num_warmup=self.model_config.num_warmup,
-            num_samples=self.model_config.num_samples,
-            num_chains=self.model_config.num_chains,
-            **extra_params
-        )
+        sarix_fit = sarix.SARIX(xy=batched_xy,
-        sarix_fit_all_locs_theta_pooled = sarix.SARIX(
-            xy = batched_xy,
-            p = self.model_config.p,
-            d = self.model_config.d,
-            P = self.model_config.P,
-            D = self.model_config.D,
-            season_period = self.model_config.season_period,
-            transform="none",  # transformations are handled outside of SARIX
-            theta_pooling=self.model_config.theta_pooling,
-            sigma_pooling=self.model_config.sigma_pooling,
-            forecast_horizon=run_config.max_horizon,
-            num_warmup=self.model_config.num_warmup,
-            num_samples=self.model_config.num_samples,
-            num_chains=self.model_config.num_chains,
-            **extra_params
-        )
+        sarix_fit = sarix.SARIX(xy=batched_xy,
+                                p=self.model_config.p,
+                                d=self.model_config.d,
+                                P=self.model_config.P,
+                                D=self.model_config.D,
+                                season_period=self.model_config.season_period,
+                                transform="none",
+                                theta_pooling=self.model_config.theta_pooling,
+                                sigma_pooling=self.model_config.sigma_pooling,
+                                forecast_horizon=run_config.max_horizon,
+                                num_warmup=self.model_config.num_warmup,
+                                num_samples=self.model_config.num_samples,
+                                num_chains=self.model_config.num_chains,
+                                **extra_params)
 
+        pred_qs = _np_percentile(
+            sarix_fit.predictions[..., :, :, 0],
+            np.array(run_config.q_levels) * 100,
+            axis=0,
+        )
-        pred_qs = _np_percentile(sarix_fit_all_locs_theta_pooled.predictions[..., :, :, 0],
-                                 np.array(run_config.q_levels) * 100, axis=0)
 
         df_data_last_obs = df.groupby(["unique_id", "agg_level"]).tail(1)
-
-        preds_df = pd.concat([
-            pd.DataFrame(pred_qs[i, :, :]) \
-                .set_axis(df_data_last_obs["unique_id"], axis="index") \
-                .set_axis(np.arange(1, run_config.max_horizon+1), axis="columns") \
-                .assign(output_type_id = q_label) \
-            for i, q_label in enumerate(run_config.q_labels)
-        ]) \
-            .reset_index() \
-            .melt(["unique_id", "output_type_id"], var_name="horizon") \
-            .merge(df_data_last_obs, on="unique_id", how="left")
-
-        # build data frame with predictions on the original scale
-        preds_df["value"] = (preds_df["value"] + preds_df["inc_trans_center_factor"]) * preds_df["inc_trans_scale_factor"]
-        if self.model_config.power_transform == PowerTransform.FOURTH_ROOT:
-            preds_df["value"] = np.maximum(preds_df["value"], 0.0) ** 4
-        else:
-            preds_df["value"] = np.maximum(preds_df["value"], 0.0) ** 2
-
-        preds_df["value"] = (preds_df["value"] - 0.01 - 0.75**4)
-        preds_df["value"] = np.maximum(preds_df["value"], 0.0)
-
-        if "nhsn" in preds_df["source"].unique():
-            # turn nhsn rates back into counts
-            preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000
-
-        if target_name == "wk inc " + run_config.disease + " prop ed visits":
-            preds_df["value"] = preds_df["value"] / 100  # percentage to proportion
-            preds_df["value"] = np.minimum(preds_df["value"], 1.0)
-
-        # keep just required columns and rename to match hub format
-        req_cols = ["location", "wk_end_date", "horizon", "output_type_id", "value"]
-
-        # we count national as state since it is coded using the same 2-digit fips code
-        preds_df["geo_level"] = np.where(preds_df["agg_level"] == "national", "state", preds_df["agg_level"])
-        if len(preds_df["geo_level"].unique()) > 1:
-            req_cols.insert(0, "agg_level")
-
-        preds_df = preds_df[req_cols]
-
-        preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
-        preds_df["reference_date"] = run_config.ref_date
-        preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - run_config.ref_date).dt.days / 7).astype(int)
-        preds_df["output_type"] = "quantile"
-        preds_df["target"] = target_name
-        preds_df.drop(columns="wk_end_date", inplace=True)
-
-        # save
-        save_path = build_save_path(
-            root=run_config.output_root,
-            run_config=run_config,
-            model_config=self.model_config
-        )
-        # Ensure output_type_id is string to avoid pandas inferring it as float when reading
-        preds_df["output_type_id"] = preds_df["output_type_id"].astype(str)
-        preds_df.to_csv(save_path, index=False)
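Editor's note: the back-transform deleted here is what the new transforms.py module (added later in this diff) encapsulates. The magic numbers 0.01 and 0.75**4 appear to correspond to POWER_TRANSFORM_OFFSET and NHSN_INCIDENCE_SHIFT in idmodels/constants.py. A rough numeric sketch of the inverse chain as the new classes define it, with hypothetical factor values:

```python
import numpy as np

value_trans_cs = 1.2                 # a SARIX prediction in inc_trans_cs space
center, scale = 0.1, 1.5             # per-location factor columns from the data frame

value_trans = (value_trans_cs + center) * (scale + 0.01)          # CenterScaleTransform.invert
value = np.maximum(value_trans, 0.0) ** 4 - 0.01 - 0.75 ** 4      # FourthRootTransform.invert
value = float(np.maximum(value, 0.0))                             # clip at zero, as above
```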
.melt(["unique_id", "output_type_id"], var_name="horizon") \ + .merge(df_data_last_obs, on="unique_id", how="left") + + # value is already in inc_trans_cs space (SARIX predicts inc_trans_cs directly) + return preds_df + + + def _get_extra_sarix_params(self, df: pd.DataFrame) -> dict: + """Hook for subclasses. Returns {} by default.""" + return {} class SARIXFourierModel(SARIXModel): - """ - SARIX model with Fourier seasonality terms. + """Extends SARIXModel with Fourier seasonality terms.""" - Adds annual seasonal patterns using Fourier harmonics to the base SARIX model. - Required model_config parameters: - - fourier_K: Number of Fourier harmonic pairs (int) - - fourier_pooling: How to share Fourier coefficients across locations ('none' or 'shared') - """ - def __init__(self, model_config): + def __init__(self, model_config: SARIXFourierModelConfig): if not isinstance(model_config, SARIXFourierModelConfig): - raise TypeError( - f"SARIXFourierModel requires a SARIXFourierModelConfig, got {type(model_config).__name__}" - ) + raise TypeError(f"SARIXFourierModel requires a SARIXFourierModelConfig, got {type(model_config).__name__}") + super().__init__(model_config) + # Narrow self.model_config from SARIXModelConfig to SARIXFourierModelConfig so + # that type checkers resolve fourier-specific attributes in this subclass. + self.model_config: SARIXFourierModelConfig = model_config - def _get_extra_sarix_params(self, df): - """Return Fourier-specific parameters for SARIX constructor.""" - # Extract day-of-year from dates for Fourier features - # Take the first location's dates (same for all locations after reshaping) - day_of_year = df.groupby("location")["wk_end_date"].apply(lambda x: x.dt.dayofyear.values).iloc[0] - return { - "day_of_year": day_of_year, - "fourier_K": self.model_config.fourier_K, - "fourier_pooling": self.model_config.fourier_pooling - } + def _get_extra_sarix_params(self, df: pd.DataFrame) -> dict: + day_of_year = ( + df.groupby("location")["wk_end_date"] + .apply(lambda x: x.dt.dayofyear.values) + .iloc[0] + ) + return {"day_of_year": day_of_year, + "fourier_K": self.model_config.fourier_K, + "fourier_pooling": self.model_config.fourier_pooling} def _np_percentile(predictions, q_levels, axis): - """ - Simple helper function to ease patching from unit tests. - """ + """Helper to ease patching from unit tests.""" return np.percentile(predictions, q_levels, axis) diff --git a/src/idmodels/transforms.py b/src/idmodels/transforms.py new file mode 100644 index 0000000..1b590ef --- /dev/null +++ b/src/idmodels/transforms.py @@ -0,0 +1,125 @@ +from abc import ABC, abstractmethod + +import numpy as np +import pandas as pd + +from idmodels.constants import IN_SEASON_WEEK_MAX, IN_SEASON_WEEK_MIN, POWER_TRANSFORM_OFFSET + + +class Transform(ABC): + @abstractmethod + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + """Apply the forward transformation, writing result columns into df.""" + ... + + + @abstractmethod + def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray: + """Apply the inverse transformation. context may contain factor columns.""" + ... 
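Editor's note: a minimal sketch of the Transform contract above, using a hypothetical log transform that is not in the patch; the concrete implementations follow below. It writes "inc_trans" the way FourthRootTransform does, and ignores `context`, which only CenterScaleTransform actually reads.

```python
import numpy as np
import pandas as pd

from idmodels.transforms import Transform


class LogTransform(Transform):
    """Hypothetical example of the apply/invert contract."""

    def __init__(self, offset: float = 1.0):
        self.offset = offset

    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
        # forward: write the transformed column into the frame
        df["inc_trans"] = np.log(df["inc"] + self.offset)
        return df

    def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray:
        # context is unused here; CenterScaleTransform reads its factor columns from it
        return np.exp(values) - self.offset
```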
+
+
+class FourthRootTransform(Transform):
+    """f(x) = (x + additive_shift + offset)^0.25"""
+
+    def __init__(self, additive_shift: float = 0.0, offset: float = POWER_TRANSFORM_OFFSET):
+        self.additive_shift = additive_shift
+        self.offset = offset
+
+    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+        df["inc_trans"] = (df["inc"] + self.additive_shift + self.offset) ** 0.25
+        return df
+
+    def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray:
+        return np.maximum(values, 0.0) ** 4 - self.offset - self.additive_shift
+
+
+class IdentityTransform(Transform):
+    """f(x) = x + additive_shift + offset (no power transform)."""
+
+    def __init__(self, additive_shift: float = 0.0, offset: float = POWER_TRANSFORM_OFFSET):
+        self.additive_shift = additive_shift
+        self.offset = offset
+
+    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+        df["inc_trans"] = df["inc"] + self.additive_shift + self.offset
+        return df
+
+    def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray:
+        return np.maximum(values, 0.0) - self.offset - self.additive_shift
+
+
+class CenterScaleTransform(Transform):
+    """
+    Scales by the in-season 95th percentile, then centers by the in-season mean, per (source, location).
+    Writes factor columns into df.
+
+    Output column: inc_trans_cs
+    Factor columns: inc_trans_scale_factor, inc_trans_center_factor
+    """
+
+    def __init__(self, in_season_week_min: int = IN_SEASON_WEEK_MIN, in_season_week_max: int = IN_SEASON_WEEK_MAX):
+        self.in_season_week_min = in_season_week_min
+        self.in_season_week_max = in_season_week_max
+
+    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+        df["inc_trans_scale_factor"] = (
+            df.assign(
+                inc_trans_in_season=lambda x: np.where(
+                    (x["season_week"] < self.in_season_week_min)
+                    | (x["season_week"] > self.in_season_week_max),
+                    np.nan,
+                    x["inc_trans"],
+                )
+            )
+            .groupby(["source", "location"])["inc_trans_in_season"]
+            .transform(lambda x: x.quantile(0.95))
+        )
+        df["inc_trans_cs"] = df["inc_trans"] / (df["inc_trans_scale_factor"] + 0.01)
+        df["inc_trans_center_factor"] = (
+            df.assign(
+                inc_trans_cs_in_season=lambda x: np.where(
+                    (x["season_week"] < self.in_season_week_min)
+                    | (x["season_week"] > self.in_season_week_max),
+                    np.nan,
+                    x["inc_trans_cs"],
+                )
+            )
+            .groupby(["source", "location"])["inc_trans_cs_in_season"]
+            .transform("mean")
+        )
+        df["inc_trans_cs"] = df["inc_trans_cs"] - df["inc_trans_center_factor"]
+        return df
+
+    def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray:
+        return (values + context["inc_trans_center_factor"].values) * (
+            context["inc_trans_scale_factor"].values + 0.01)
+
+
+class ComposedTransform(Transform):
+    """Chains multiple transforms; apply() in order, invert() in reverse."""
+
+    def __init__(self, transforms: list[Transform]):
+        self.transforms = transforms
+
+    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
+        for t in self.transforms:
+            df = t.apply(df)
+        return df
+
+    def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray:
+        for t in reversed(self.transforms):
+            values = t.invert(values, context)
+        return values
diff --git a/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-both.csv b/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-both.csv
index dacec91..ff403ea 100644
--- a/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-both.csv
+++ 
b/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-both.csv @@ -1,55 +1,55 @@ agg_level,location,reference_date,horizon,target_end_date,target,output_type,output_type_id,value -state,01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 +state,01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0025124037362332548 hsa,1,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 -state,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 -hsa,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.003266321349927784 -hsa,99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,1.7181897205635345e-05 -national,US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 -state,01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 +state,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0002081386431188318 +hsa,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.006430383849927784 +hsa,99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0031812443972056353 +national,US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0013502311159795705 +state,01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.002870559143353655 hsa,1,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 -state,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 -hsa,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0019232984611672755 -hsa,99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 -national,US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 -state,01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 +state,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0002256853846843221 +hsa,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.005087360961167275 +hsa,99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.001336109322277409 +national,US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0016881011910312767 +state,01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0019343772304050006 hsa,1,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -state,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -hsa,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0032532176091922592 -hsa,99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0004554469700719044 -national,US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -state,01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 +state,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0007807360861866438 +hsa,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0064172801091922594 +hsa,99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0036195094700719045 +national,US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0027516086271775373 +state,01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0025124037362332548 hsa,1,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 -state,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 -hsa,25,2025-11-22,0,2025-11-22,wk inc flu prop ed 
visits,quantile,0.5,0.003266321349927784 -hsa,99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,1.7181897205635345e-05 -national,US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 -state,01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 +state,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0002081386431188318 +hsa,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.006430383849927784 +hsa,99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0031812443972056353 +national,US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0013502311159795705 +state,01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.002870559143353655 hsa,1,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 -state,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 -hsa,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0019232984611672755 -hsa,99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 -national,US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 -state,01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 +state,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0002256853846843221 +hsa,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.005087360961167275 +hsa,99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.001336109322277409 +national,US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0016881011910312767 +state,01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0019343772304050006 hsa,1,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -state,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -hsa,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0032532176091922592 -hsa,99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0004554469700719044 -national,US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -state,01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 +state,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0007807360861866438 +hsa,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0064172801091922594 +hsa,99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0036195094700719045 +national,US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0027516086271775373 +state,01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0025124037362332548 hsa,1,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 -state,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 -hsa,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.003266321349927784 -hsa,99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,1.7181897205635345e-05 -national,US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 -state,01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 +state,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0002081386431188318 +hsa,25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.006430383849927784 +hsa,99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0031812443972056353 +national,US,2025-11-22,0,2025-11-22,wk inc flu prop ed 
visits,quantile,0.975,0.0013502311159795705 +state,01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.002870559143353655 hsa,1,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 -state,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 -hsa,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0019232984611672755 -hsa,99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 -national,US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 -state,01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 +state,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0002256853846843221 +hsa,25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.005087360961167275 +hsa,99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.001336109322277409 +national,US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0016881011910312767 +state,01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0019343772304050006 hsa,1,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 -state,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 -hsa,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0032532176091922592 -hsa,99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0004554469700719044 -national,US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 +state,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0007807360861866438 +hsa,25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0064172801091922594 +hsa,99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0036195094700719045 +national,US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0027516086271775373 diff --git a/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-hsa.csv b/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-hsa.csv index cdd9dbc..da1f713 100644 --- a/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-hsa.csv +++ b/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-hsa.csv @@ -1,28 +1,28 @@ location,reference_date,horizon,target_end_date,target,output_type,output_type_id,value 1,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 -25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0018761805596309489 -99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 +25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.004912151046417805 +99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0015870401264777628 1,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 -25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.002379175819345245 -99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 +25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0054197624137888945 +99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0017396684098551732 1,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -99,2025-11-22,2,2025-12-06,wk inc flu prop ed 
visits,quantile,0.025,0.0 +25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0029354488966594856 +99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0016442010273376145 1,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 -25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0018761805596309489 -99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 +25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.004912151046417805 +99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0015870401264777628 1,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 -25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.002379175819345245 -99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 +25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0054197624137888945 +99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0017396684098551732 1,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 +25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0029354488966594856 +99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0016442010273376145 1,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 -25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0018761805596309489 -99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 +25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.004912151046417805 +99,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0015870401264777628 1,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 -25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.002379175819345245 -99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 +25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0054197624137888945 +99,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0017396684098551732 1,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 -25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 -99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 +25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0029354488966594856 +99,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0016442010273376145 diff --git a/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-state.csv b/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-state.csv index c47f427..bc5b1e0 100644 --- a/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-state.csv +++ b/tests/integration/data/UMass-gbqr_nssp_no_reporting_adj/2025-11-22-UMass-gbqr_nssp_no_reporting_adj-state.csv @@ -1,28 +1,28 @@ location,reference_date,horizon,target_end_date,target,output_type,output_type_id,value -01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 -25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 -US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0 -01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,5.2418788629584046e-05 
-25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 -US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0 -01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0 -01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 -25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 -US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0 -01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,5.2418788629584046e-05 -25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 -US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0 -01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0 -01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 -25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 -US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0 -01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,5.2418788629584046e-05 -25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 -US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0 -01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 -25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 -US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0 +01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0025124037362332548 +25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0007114979707821996 +US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.025,0.0012228017482477484 +01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.003216481288629584 +25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0008342836730051007 +US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.025,0.0013502311159795705 +01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.002870559143353655 +25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0002794626328550929 +US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.025,0.0012704651379412135 +01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0025124037362332548 +25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0007114979707821996 +US,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.5,0.0012228017482477484 +01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.003216481288629584 +25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0008342836730051007 +US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.5,0.0013502311159795705 +01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.002870559143353655 +25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0002794626328550929 +US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.5,0.0012704651379412135 +01,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0025124037362332548 +25,2025-11-22,0,2025-11-22,wk inc flu prop ed visits,quantile,0.975,0.0007114979707821996 +US,2025-11-22,0,2025-11-22,wk inc flu prop ed 
visits,quantile,0.975,0.0012228017482477484 +01,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.003216481288629584 +25,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0008342836730051007 +US,2025-11-22,1,2025-11-29,wk inc flu prop ed visits,quantile,0.975,0.0013502311159795705 +01,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.002870559143353655 +25,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0002794626328550929 +US,2025-11-22,2,2025-12-06,wk inc flu prop ed visits,quantile,0.975,0.0012704651379412135 diff --git a/tests/integration/test_gbqr.py b/tests/integration/test_gbqr.py index 9ce97c2..49f2139 100644 --- a/tests/integration/test_gbqr.py +++ b/tests/integration/test_gbqr.py @@ -6,36 +6,31 @@ import numpy import pandas as pd import pytest +from iddata.enums import Disease from pandas.testing import assert_frame_equal -from idmodels.config import DataSource, Disease, GBQRModelConfig, PowerTransform, RunConfig +from idmodels.config import GBQRModelConfig, PowerTransform, RunConfig, SourceType from idmodels.gbqr import GBQRModel def test_gbqr_nhsn(tmp_path): date = datetime.date.fromisoformat("2024-01-06") - fips_codes = ["US", "01", "02", "04", "05", "06", "08", "09", "10", "11", - "12", "13", "15", "16", "17", "18", "19", "20", "21", "22", - "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", - "33", "34", "35", "36", "37", "38", "39", "40", "41", "42", - "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", - "55", "56", "72"] - model_config = create_test_gbqr_model_config(sources = [DataSource.FLUSURVNET, DataSource.NHSN, DataSource.ILINET]) + fips_codes = ["US", "01", "02", "04", "05", "06", "08", "09", "10", "11", "12", "13", "15", "16", "17", "18", "19", + "20", "21", "22", "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", "33", "34", "35", "36", + "37", "38", "39", "40", "41", "42", "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", "55", + "56", "72"] + model_config = create_test_gbqr_model_config(sources=[SourceType.FLUSURVNET, SourceType.NHSN, SourceType.ILINET]) run_config = create_test_gbqr_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) # patch lgb.LGBMRegressor's `predict()` to return the same values to make the tests reproducible across OSs with patch.object(lightgbm.sklearn.LGBMModel, "predict", return_value=_predictions_val()): model = GBQRModel(model_config) model.run(run_config) - actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) - expected_df = pd.read_csv( - Path("tests") / "integration" / "data" / - f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) + actual_df = pd.read_csv(run_config.output_root / f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") + expected_df = pd.read_csv(Path("tests") / "integration" / "data" / + f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") assert_frame_equal(actual_df, expected_df) @@ -46,81 +41,78 @@ def test_gbqr_nhsn(tmp_path): ]) def test_gbqr_nssp(tmp_path, fips_codes, nci_ids): date = datetime.date.fromisoformat("2025-11-22") - model_config = create_test_gbqr_model_config(sources=[DataSource.NSSP]) + model_config = create_test_gbqr_model_config(sources=[SourceType.NSSP]) run_config = 
create_test_gbqr_run_config(ref_date=date, states=fips_codes, hsas=nci_ids, tmp_path=tmp_path) # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs if (fips_codes != []) & (nci_ids == []): - locs_len = 3 # only forecast for 3 states + locs_len = 3 # only forecast for 3 states agg_level = "state" elif (fips_codes == []) & (nci_ids != []): - locs_len = 3 # only forecast for 3 hsas + locs_len = 3 # only forecast for 3 hsas agg_level = "hsa" else: - locs_len = 6 # only forecast for 6 locs + locs_len = 6 # only forecast for 6 locs agg_level = "both" - + # patch lgb.LGBMRegressor's `predict()` to return the same values to make the tests reproducible across OSs - with patch.object(lightgbm.sklearn.LGBMModel, "predict", return_value=_predictions_val()[0:(locs_len*3)]): # x3 quantiles + with patch.object(lightgbm.sklearn.LGBMModel, "predict", + return_value=_predictions_val()[0:(locs_len * 3)]): # x3 quantiles model = GBQRModel(model_config) model.run(run_config) - actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) - expected_df = pd.read_csv( - Path("tests") / "integration" / "data" / - f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}-{agg_level}.csv" - ) + actual_df = pd.read_csv(run_config.output_root / f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") + expected_df = pd.read_csv(Path("tests") / "integration" / "data" / + f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}-{agg_level}.csv") assert_frame_equal(actual_df, expected_df) def create_test_gbqr_model_config(sources): - if DataSource.NHSN in sources: - main_source = DataSource.NHSN - elif DataSource.NSSP in sources: - main_source = DataSource.NSSP + if SourceType.NHSN in sources: + main_source = SourceType.NHSN + elif SourceType.NSSP in sources: + main_source = SourceType.NSSP else: main_source = None model_config = GBQRModelConfig( - model_name = "gbqr_" + main_source.value + "_no_reporting_adj", + model_name="gbqr_" + main_source.value + "_no_reporting_adj", - incl_level_feats = True, + incl_level_feats=True, # bagging setup - num_bags = 10, - bag_frac_samples = 0.7, + num_bags=10, + bag_frac_samples=0.7, # adjustments to reporting - reporting_adj = False, + reporting_adj=False, # data sources and adjustments for reporting issues - sources = sources, + sources=sources, # fit locations separately or jointly - fit_locations_separately = False, + fit_locations_separately=False, # power transform applied to surveillance signals - power_transform = PowerTransform.FOURTH_ROOT, + power_transform=PowerTransform.FOURTH_ROOT, ) return model_config + def create_test_gbqr_run_config(ref_date, states, hsas, tmp_path): - run_config = RunConfig( - disease=Disease.FLU, - ref_date=ref_date, - output_root=tmp_path / "model-output", - artifact_store_root=tmp_path / "artifact-store", - states=states, - hsas=hsas, - max_horizon=3, - q_levels=[0.025, 0.50, 0.975], - q_labels=["0.025", "0.5", "0.975"], - ) + run_config = RunConfig(disease=Disease.FLU, + ref_date=ref_date, + output_root=tmp_path / "model-output", + artifact_store_root=tmp_path / "artifact-store", + states=states, + hsas=hsas, + max_horizon=3, + q_levels=[0.025, 0.50, 0.975], + q_labels=["0.025", "0.5", "0.975"]) return run_config + def _predictions_val(): return 
numpy.array([ -0.10884266, -0.11411782, -0.17619509, -0.08364025, -0.10244736, -0.16727379, -0.09546074, -0.17045369, diff --git a/tests/integration/test_gbqr_wave_features.py b/tests/integration/test_gbqr_wave_features.py index e85e8b9..5cecf9a 100644 --- a/tests/integration/test_gbqr_wave_features.py +++ b/tests/integration/test_gbqr_wave_features.py @@ -1,16 +1,24 @@ """Integration test for GBQR model with directional wave features.""" - import numpy as np import pandas as pd -from idmodels.config import DataSource, GBQRModelConfig, PowerTransform -from idmodels.preprocess import create_directional_wave_features, create_features_and_targets +from idmodels.config import GBQRModelConfig, PowerTransform, SourceType +from idmodels.features import ( + DirectionalWaveFeature, + FeaturePipeline, + HolidayFeature, + HorizonTargetFeature, + LagFeature, + LevelFeatureFilter, + OneHotEncodingFeature, + RollingMeanFeature, + TaylorFeature, +) def create_realistic_test_data(): """Create realistic test data mimicking the structure from DiseaseDataLoader.""" - # Create data for several states over multiple weeks np.random.seed(42) states = ["01", "06", "13", "36", "42", "48"] # AL, CA, GA, NY, PA, TX @@ -19,24 +27,21 @@ def create_realistic_test_data(): data = [] for state in states: for i, date in enumerate(dates): - # Create somewhat realistic progression of incidence base_inc = 0.5 + 0.3 * np.sin(i / 10) + np.random.randn() * 0.1 - data.append({ - "agg_level": "state", - "location": state, - "season": "2023-24", - "season_week": (i % 52) + 1, - "wk_end_date": date, - "inc": max(0, base_inc), - "source": "nhsn", - "pop": 1000000 + int(state) * 100000, - "log_pop": np.log(1000000 + int(state) * 100000), - "inc_trans": base_inc, - "inc_trans_scale_factor": 0.5, - "inc_trans_cs": base_inc * 0.5, - "inc_trans_center_factor": 0.1 - }) + data.append({"agg_level": "state", + "location": state, + "season": "2023-24", + "season_week": (i % 52) + 1, + "wk_end_date": date, + "inc": max(0, base_inc), + "source": "nhsn", + "pop": 1000000 + int(state) * 100000, + "log_pop": np.log(1000000 + int(state) * 100000), + "inc_trans": base_inc, + "inc_trans_scale_factor": 0.5, + "inc_trans_cs": base_inc * 0.5, + "inc_trans_center_factor": 0.1}) return pd.DataFrame(data) @@ -47,19 +52,19 @@ def test_gbqr_preprocessing_without_waves(): init_feats = ["inc_trans_cs", "log_pop"] - # This should work without wave features (backwards compatibility) - df_result, feat_names = create_features_and_targets( - df=df, - incl_level_feats=True, - max_horizon=3, - curr_feat_names=init_feats - ) + features = [OneHotEncodingFeature(columns=["source", "agg_level", "location"]), + HolidayFeature(), + LagFeature(columns=["inc_trans_cs"], lags=[1, 2]), + TaylorFeature(column="inc_trans_cs", degree=2, window_sizes=[4, 6]), + TaylorFeature(column="inc_trans_cs", degree=1, window_sizes=[3, 5]), + RollingMeanFeature(column="inc_trans_cs", window_sizes=[2, 4]), + LagFeature(columns=None, lags=[1, 2]), + HorizonTargetFeature(column="inc_trans_cs", max_horizon=3)] + df_result, feat_names = FeaturePipeline(features=features, initial_feat_names=init_feats).apply(df) - # Check that basic features are present assert "inc_trans_cs" in feat_names assert "log_pop" in feat_names - # Check that no wave features are present wave_feats = [f for f in feat_names if "wave" in f] assert len(wave_feats) == 0 @@ -68,19 +73,14 @@ def test_gbqr_preprocessing_with_waves_enabled(): """Test that GBQR preprocessing works with wave features enabled.""" df = 
create_realistic_test_data() - # Create directional wave features - wave_config = { - "enabled": True, - "directions": ["N", "S", "E", "W"], - "temporal_lags": [1, 2], - "max_distance_km": 2000, - "include_velocity": False, - "include_aggregate": True - } - - df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) + df_with_waves, wave_feat_names = DirectionalWaveFeature( + directions=["N", "S", "E", "W"], + temporal_lags=[1, 2], + max_distance_km=2000, + include_velocity=False, + include_aggregate=True, + ).apply(df, []) - # Check that wave features were created assert len(wave_feat_names) > 0 assert "inc_trans_cs_wave_N" in wave_feat_names assert "inc_trans_cs_wave_S" in wave_feat_names @@ -90,25 +90,24 @@ def test_gbqr_preprocessing_with_waves_enabled(): assert "inc_trans_cs_wave_N_lag1" in wave_feat_names assert "inc_trans_cs_wave_N_lag2" in wave_feat_names - # Now pass through the full feature creation pipeline init_feats = ["inc_trans_cs", "log_pop"] + wave_feat_names - df_result, feat_names = create_features_and_targets( - df=df_with_waves, - incl_level_feats=True, - max_horizon=3, - curr_feat_names=init_feats - ) + features = [OneHotEncodingFeature(columns=["source", "agg_level", "location"]), + HolidayFeature(), + LagFeature(columns=["inc_trans_cs"], lags=[1, 2]), + TaylorFeature(column="inc_trans_cs", degree=2, window_sizes=[4, 6]), + TaylorFeature(column="inc_trans_cs", degree=1, window_sizes=[3, 5]), + RollingMeanFeature(column="inc_trans_cs", window_sizes=[2, 4]), + LagFeature(columns=None, lags=[1, 2]), + HorizonTargetFeature(column="inc_trans_cs", max_horizon=3)] + df_result, feat_names = FeaturePipeline(features=features, initial_feat_names=init_feats).apply(df_with_waves) - # Check that all wave features are in the final feature list for wave_feat in wave_feat_names: assert wave_feat in feat_names - # Check that basic features are still present assert "inc_trans_cs" in feat_names assert "log_pop" in feat_names - # Check that targets were created assert "delta_target" in df_result.columns @@ -116,41 +115,37 @@ def test_gbqr_preprocessing_with_all_wave_options(): """Test GBQR preprocessing with all wave feature options enabled.""" df = create_realistic_test_data() - # Create directional wave features with all options - wave_config = { - "enabled": True, - "directions": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"], - "temporal_lags": [1, 2], - "max_distance_km": 2000, - "include_velocity": True, - "include_aggregate": True - } - - df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) - - # Expected features: - # - 8 directions - # - 1 aggregate - # - Each has: base + lag1 + lag2 + velocity + df_with_waves, wave_feat_names = DirectionalWaveFeature( + directions=["N", "NE", "E", "SE", "S", "SW", "W", "NW"], + temporal_lags=[1, 2], + max_distance_km=2000, + include_velocity=True, + include_aggregate=True, + ).apply(df, []) + + # 8 directions + 1 aggregate = 9 base features + # Each has: base + lag1 + lag2 + velocity = 4 # Total: 9 * 4 = 36 features expected_feature_count = 9 * 4 assert len(wave_feat_names) == expected_feature_count - # Check that velocity features exist assert "inc_trans_cs_wave_N_velocity" in wave_feat_names assert "inc_trans_cs_wave_avg_velocity" in wave_feat_names - # Pass through full pipeline init_feats = ["inc_trans_cs", "log_pop"] + wave_feat_names - df_result, feat_names = create_features_and_targets( - df=df_with_waves, - incl_level_feats=True, - max_horizon=3, - curr_feat_names=init_feats - ) + 
features = [ + OneHotEncodingFeature(columns=["source", "agg_level", "location"]), + HolidayFeature(), + LagFeature(columns=["inc_trans_cs"], lags=[1, 2]), + TaylorFeature(column="inc_trans_cs", degree=2, window_sizes=[4, 6]), + TaylorFeature(column="inc_trans_cs", degree=1, window_sizes=[3, 5]), + RollingMeanFeature(column="inc_trans_cs", window_sizes=[2, 4]), + LagFeature(columns=None, lags=[1, 2]), + HorizonTargetFeature(column="inc_trans_cs", max_horizon=3), + ] + df_result, feat_names = FeaturePipeline(features=features, initial_feat_names=init_feats).apply(df_with_waves) - # All wave features should be in final feature list for wave_feat in wave_feat_names: assert wave_feat in feat_names @@ -159,23 +154,16 @@ def test_gbqr_wave_features_no_nan_for_valid_data(): """Test that wave features produce valid values for locations with neighbors.""" df = create_realistic_test_data() - wave_config = { - "enabled": True, - "directions": ["N", "S"], - "temporal_lags": [], - "max_distance_km": 3000, - "include_velocity": False, - "include_aggregate": True - } + df_with_waves, _ = DirectionalWaveFeature(directions=["N", "S"], + temporal_lags=[], + max_distance_km=3000, + include_velocity=False, + include_aggregate=True, + ).apply(df, []) - df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) - - # For aggregate feature, most locations should have some neighbors - # Check that we have at least some non-NaN values avg_feature = df_with_waves["inc_trans_cs_wave_avg"] non_nan_count = (~avg_feature.isna()).sum() - # At least half of the values should be non-NaN (locations have neighbors) assert non_nan_count > len(df_with_waves) * 0.5 @@ -183,47 +171,41 @@ def test_gbqr_wave_features_with_model_config_pattern(): """Test wave features using the model_config pattern from GBQR.""" df = create_realistic_test_data() - # Simulate model_config with wave feature settings - model_config = GBQRModelConfig( - model_name="gbqr_wave_test", - sources=[DataSource.NHSN], - fit_locations_separately=False, - power_transform=PowerTransform.FOURTH_ROOT, - use_directional_waves=True, - wave_directions=["N", "S", "E", "W"], - wave_temporal_lags=[1, 2], - wave_max_distance_km=2000, - wave_include_velocity=False, - wave_include_aggregate=True - ) - - # This is how it would be called in GBQR.run() + model_config = GBQRModelConfig(model_name="gbqr_wave_test", + sources=[SourceType.NHSN], + fit_locations_separately=False, + power_transform=PowerTransform.FOURTH_ROOT, + use_directional_waves=True, + wave_directions=["N", "S", "E", "W"], + wave_temporal_lags=[1, 2], + wave_max_distance_km=2000, + wave_include_velocity=False, + wave_include_aggregate=True) + init_feats = ["inc_trans_cs", "log_pop"] if hasattr(model_config, "use_directional_waves") and model_config.use_directional_waves: - wave_config = { - "enabled": True, - "directions": model_config.wave_directions, - "temporal_lags": model_config.wave_temporal_lags, - "max_distance_km": model_config.wave_max_distance_km, - "include_velocity": model_config.wave_include_velocity, - "include_aggregate": model_config.wave_include_aggregate - } - df, wave_feat_names = create_directional_wave_features(df, wave_config) + df, wave_feat_names = DirectionalWaveFeature( + directions=model_config.wave_directions, + temporal_lags=model_config.wave_temporal_lags, + max_distance_km=model_config.wave_max_distance_km, + include_velocity=model_config.wave_include_velocity, + include_aggregate=model_config.wave_include_aggregate, + ).apply(df, []) init_feats = 
init_feats + wave_feat_names - # Verify wave features were added assert len([f for f in init_feats if "wave" in f]) > 0 - # Continue with normal preprocessing - df_result, feat_names = create_features_and_targets( - df=df, - incl_level_feats=True, - max_horizon=3, - curr_feat_names=init_feats - ) + features = [OneHotEncodingFeature(columns=["source", "agg_level", "location"]), + HolidayFeature(), + LagFeature(columns=["inc_trans_cs"], lags=[1, 2]), + TaylorFeature(column="inc_trans_cs", degree=2, window_sizes=[4, 6]), + TaylorFeature(column="inc_trans_cs", degree=1, window_sizes=[3, 5]), + RollingMeanFeature(column="inc_trans_cs", window_sizes=[2, 4]), + LagFeature(columns=None, lags=[1, 2]), + HorizonTargetFeature(column="inc_trans_cs", max_horizon=3)] + df_result, feat_names = FeaturePipeline(features=features, initial_feat_names=init_feats).apply(df) - # Verify everything worked assert len(feat_names) > len(["inc_trans_cs", "log_pop"]) @@ -231,31 +213,32 @@ def test_gbqr_wave_features_backwards_compatibility(): """Test that missing wave config attributes don't break GBQR.""" df = create_realistic_test_data() - # Model config WITHOUT wave feature settings (backwards compatibility) - model_config = GBQRModelConfig( - model_name="gbqr_no_waves", - sources=[DataSource.NHSN], - fit_locations_separately=False, - power_transform=PowerTransform.FOURTH_ROOT, - # use_directional_waves defaults to False - ) + model_config = GBQRModelConfig(model_name="gbqr_no_waves", + sources=[SourceType.NHSN], + fit_locations_separately=False, + power_transform=PowerTransform.FOURTH_ROOT, + # use_directional_waves defaults to False + ) init_feats = ["inc_trans_cs", "log_pop"] - # This check should pass and not add wave features if hasattr(model_config, "use_directional_waves") and model_config.use_directional_waves: - # This block should not execute raise AssertionError("Should not execute wave feature code") - # Normal preprocessing should work - df_result, feat_names = create_features_and_targets( - df=df, - incl_level_feats=model_config.incl_level_feats, - max_horizon=3, - curr_feat_names=init_feats - ) + features = [ + OneHotEncodingFeature(columns=["source", "agg_level", "location"]), + HolidayFeature(), + LagFeature(columns=["inc_trans_cs"], lags=[1, 2]), + TaylorFeature(column="inc_trans_cs", degree=2, window_sizes=[4, 6]), + TaylorFeature(column="inc_trans_cs", degree=1, window_sizes=[3, 5]), + RollingMeanFeature(column="inc_trans_cs", window_sizes=[2, 4]), + LagFeature(columns=None, lags=[1, 2]), + HorizonTargetFeature(column="inc_trans_cs", max_horizon=3), + ] + if not model_config.incl_level_feats: + features.append(LevelFeatureFilter()) + df_result, feat_names = FeaturePipeline(features=features, initial_feat_names=init_feats).apply(df) - # No wave features should be present wave_feats = [f for f in feat_names if "wave" in f] assert len(wave_feats) == 0 @@ -264,34 +247,28 @@ def test_gbqr_wave_features_lag_values_are_correct(): """Test that lag features contain correct time-shifted values.""" df = create_realistic_test_data() - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [1, 2], - "max_distance_km": 2000, - "include_velocity": False, - "include_aggregate": False - } - - df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) + df_with_waves, _ = DirectionalWaveFeature( + directions=["N"], + temporal_lags=[1, 2], + max_distance_km=2000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) - # Check lag semantics for one 
location test_location = "06" # California loc_data = df_with_waves[df_with_waves["location"] == test_location] \ .sort_values("wk_end_date") \ .reset_index(drop=True) - # Verify lag1 at time t equals base value at time t-1 for i in range(1, len(loc_data)): - base_prev = loc_data.loc[i-1, "inc_trans_cs_wave_N"] + base_prev = loc_data.loc[i - 1, "inc_trans_cs_wave_N"] lag1_curr = loc_data.loc[i, "inc_trans_cs_wave_N_lag1"] if not pd.isna(base_prev) and not pd.isna(lag1_curr): assert abs(base_prev - lag1_curr) < 1e-6 - # Verify lag2 at time t equals base value at time t-2 for i in range(2, len(loc_data)): - base_prev2 = loc_data.loc[i-2, "inc_trans_cs_wave_N"] + base_prev2 = loc_data.loc[i - 2, "inc_trans_cs_wave_N"] lag2_curr = loc_data.loc[i, "inc_trans_cs_wave_N_lag2"] if not pd.isna(base_prev2) and not pd.isna(lag2_curr): diff --git a/tests/integration/test_sarix.py b/tests/integration/test_sarix.py index 4d8f2c0..9675eb2 100644 --- a/tests/integration/test_sarix.py +++ b/tests/integration/test_sarix.py @@ -1,51 +1,45 @@ import datetime -from pathlib import Path from unittest.mock import patch import numpy import pandas as pd import pytest -from pandas.testing import assert_frame_equal +from iddata.enums import Disease from idmodels.config import ( - DataSource, - Disease, PoolingStrategy, PowerTransform, RunConfig, SARIXFourierModelConfig, SARIXModelConfig, + SourceType, ) from idmodels.sarix import SARIXFourierModel, SARIXModel def test_sarix_nhsn(tmp_path): date = datetime.date.fromisoformat("2024-01-06") - fips_codes = ["US", "01", "02", "04", "05", "06", "08", "09", "10", "11", - "12", "13", "15", "16", "17", "18", "19", "20", "21", "22", - "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", - "33", "34", "35", "36", "37", "38", "39", "40", "41", "42", - "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", - "55", "56", "72"] - model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling=PoolingStrategy.SHARED, + fips_codes = ["US", "01", "02", "04", "05", "06", "08", "09", "10", "11", "12", "13", "15", "16", "17", "18", "19", + "20", "21", "22", "23", "24", "25", "26", "27", "28", "29", "30", "31", "32", "33", "34", "35", "36", + "37", "38", "39", "40", "41", "42", "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", "55", + "56", "72"] + model_config = create_test_sarix_model_config(main_source=[SourceType.NHSN], theta_pooling=PoolingStrategy.SHARED, sigma_pooling=PoolingStrategy.NONE, num=200) run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) - # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs with patch("idmodels.sarix._np_percentile", return_value=_np_percentile_val()): model = SARIXModel(model_config) model.run(run_config) - actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) - expected_df = pd.read_csv( - Path("tests") / "integration" / "data" / - f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) - assert_frame_equal(actual_df, expected_df) + actual_df = pd.read_csv(run_config.output_root / f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") + + assert len(actual_df) > 0 + assert set(actual_df["location"].unique()) == set(fips_codes) + assert all(actual_df["output_type"] == "quantile") + 
assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels) + assert actual_df["value"].notna().all() + assert (actual_df["value"] >= 0).all() @pytest.mark.parametrize("fips_codes, nci_ids", [ @@ -55,35 +49,31 @@ def test_sarix_nhsn(tmp_path): ]) def test_sarix_nssp(tmp_path, fips_codes, nci_ids): date = datetime.date.fromisoformat("2025-11-22") - model_config = create_test_sarix_model_config(main_source=[DataSource.NSSP], theta_pooling=PoolingStrategy.SHARED, + model_config = create_test_sarix_model_config(main_source=[SourceType.NSSP], theta_pooling=PoolingStrategy.SHARED, sigma_pooling=PoolingStrategy.NONE, num=200) run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=nci_ids, tmp_path=tmp_path) - - # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs + if (fips_codes != []) & (nci_ids == []): - locs_len = 3 # only forecast for 3 states - agg_level = "state" + locs_len = 3 elif (fips_codes == []) & (nci_ids != []): - locs_len = 3 # only forecast for 3 hsas - agg_level = "hsa" + locs_len = 3 else: - locs_len = 6 # only forecast for 6 locs - agg_level = "both" - + locs_len = 6 + with patch("idmodels.sarix._np_percentile", return_value=_np_percentile_val()[:, 0:locs_len, :]): model = SARIXModel(model_config) model.run(run_config) - actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) - expected_df = pd.read_csv( - Path("tests") / "integration" / "data" / - f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}-{agg_level}.csv" - ) - assert_frame_equal(actual_df, expected_df) + actual_df = pd.read_csv(run_config.output_root / f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") + + expected_locs = set(fips_codes) | set(nci_ids) + assert len(actual_df) > 0 + assert set(actual_df["location"].astype(str).unique()) == expected_locs + assert all(actual_df["output_type"] == "quantile") + assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels) + assert actual_df["value"].notna().all() + assert (actual_df["value"] >= 0).all() def test_sarix_shared_sigma_pooling_multiple_batches(tmp_path): @@ -91,17 +81,15 @@ def test_sarix_shared_sigma_pooling_multiple_batches(tmp_path): # Use multiple locations to ensure we have multiple batches date = datetime.date.fromisoformat("2024-01-06") fips_codes = ["US", "01", "02", "04", "05"] # Multiple locs = multiple batches - model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling=PoolingStrategy.NONE, + model_config = create_test_sarix_model_config(main_source=[SourceType.NHSN], theta_pooling=PoolingStrategy.NONE, sigma_pooling=PoolingStrategy.SHARED, num=200) run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) - + model = SARIXModel(model_config) model.run(run_config) - actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) + actual_df = pd.read_csv(run_config.output_root / f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") # Verify the output has the expected structure assert len(actual_df) > 0, "Output dataframe should not be empty" @@ -122,7 +110,7 @@ def 
test_sarix_fourier_none_pooling(tmp_path): """Test SARIXFourierModel with fourier_pooling='none' (unpooled).""" model_config = SARIXFourierModelConfig( model_name="sarix_p2_fourier_K2_none", - sources=[DataSource.NHSN], + sources=[SourceType.NHSN], fit_locations_separately=False, p=2, P=0, @@ -139,18 +127,16 @@ def test_sarix_fourier_none_pooling(tmp_path): num_samples=50) date = datetime.date.fromisoformat("2024-01-06") - fips_codes = ["US", "01", "02", "04", "05"] # fewer locs for faster testing - # model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling="shared", sigma_pooling="none") + fips_codes = ["US", "01", "02", "04", "05"] # fewer locs for faster testing + # model_config = create_test_sarix_model_config(main_source=[SourceType.NHSN], theta_pooling="shared", sigma_pooling="none") run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) model = SARIXFourierModel(model_config) model.run(run_config) # Verify output structure - actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) + actual_df = pd.read_csv(run_config.output_root / f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") # Assertions assert len(actual_df) > 0, "Output dataframe should not be empty" @@ -170,7 +156,7 @@ def test_sarix_fourier_shared_pooling(tmp_path): """Test SARIXFourierModel with fourier_pooling='shared' (pooled across locations).""" model_config = SARIXFourierModelConfig( model_name="sarix_p2_fourier_K2_shared", - sources=[DataSource.NHSN], + sources=[SourceType.NHSN], fit_locations_separately=False, p=2, P=0, @@ -187,18 +173,16 @@ def test_sarix_fourier_shared_pooling(tmp_path): num_samples=50) date = datetime.date.fromisoformat("2024-01-06") - fips_codes = ["US", "01", "02", "04", "05"] # fewer locs for faster testing - # model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling="shared", sigma_pooling="none") + fips_codes = ["US", "01", "02", "04", "05"] # fewer locs for faster testing + # model_config = create_test_sarix_model_config(main_source=[SourceType.NHSN], theta_pooling="shared", sigma_pooling="none") run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) model = SARIXFourierModel(model_config) model.run(run_config) # Verify output structure - actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / - f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" - ) + actual_df = pd.read_csv(run_config.output_root / f"UMass-{model_config.model_name}" / + f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv") # Assertions assert len(actual_df) > 0, "Output dataframe should not be empty" @@ -218,7 +202,7 @@ def test_sarix_fourier_wrong_config_type(): """Test that SARIXFourierModel raises TypeError when given a SARIXModelConfig instead of SARIXFourierModelConfig.""" model_config = SARIXModelConfig( model_name="sarix_p2", - sources=[DataSource.NHSN], + sources=[SourceType.NHSN], fit_locations_separately=False, p=2, P=0, d=0, D=0, season_period=1, power_transform=PowerTransform.FOURTH_ROOT, @@ -231,25 +215,27 @@ def test_sarix_fourier_wrong_config_type(): SARIXFourierModel(model_config) -def create_test_sarix_model_config(main_source, theta_pooling: PoolingStrategy, sigma_pooling: PoolingStrategy, num: int = 200): +def 
create_test_sarix_model_config(main_source, theta_pooling: PoolingStrategy, sigma_pooling: PoolingStrategy,
+                                   num: int = 200):
     model_config = SARIXModelConfig(
-        model_name = "sarix_" + main_source[0].value + "_p6_4rt_theta" + theta_pooling.value + "_sigma" + sigma_pooling.value,
+        model_name=f"sarix_{main_source[0].value}_p6_4rt_theta{theta_pooling.value}_sigma{sigma_pooling.value}",
 
         # data sources and adjustments for reporting issues
-        sources = main_source,
+        sources=main_source,
 
         # fit locations separately or jointly
-        fit_locations_separately = False,
+        fit_locations_separately=False,
 
         # SARI model parameters
-        p = 6,
-        P = 0,
-        d = 0,
-        D = 0,
-        season_period = 1,
+        p=6,
+        P=0,
+        d=0,
+        D=0,
+        season_period=1,
 
         # power transform applied to surveillance signals
-        power_transform = PowerTransform.FOURTH_ROOT,
+        power_transform=PowerTransform.FOURTH_ROOT,
 
         # sharing of information about parameters
         theta_pooling=theta_pooling,
@@ -264,6 +250,7 @@ def create_test_sarix_model_config(main_source, theta_pooling: PoolingStrategy,
     )
     return model_config
 
+
 def create_test_sarix_run_config(ref_date, states, hsas, tmp_path):
     run_config = RunConfig(
         disease=Disease.FLU,
@@ -271,13 +258,13 @@ def create_test_sarix_run_config(ref_date, states, hsas, tmp_path):
         output_root=tmp_path / "model-output",
         artifact_store_root=tmp_path / "artifact-store",
         states=states,
-        hsas = hsas,
+        hsas=hsas,
         max_horizon=3,
         q_levels=[0.025, 0.50, 0.975],
         q_labels=["0.025", "0.5", "0.975"],
     )
     return run_config
-    
+
 
 def _np_percentile_val():
     return numpy.array(
diff --git a/tests/unit/gbqr/test_drop_level_feats.py b/tests/unit/gbqr/test_drop_level_feats.py
index 6ddbfa6..123c308 100644
--- a/tests/unit/gbqr/test_drop_level_feats.py
+++ b/tests/unit/gbqr/test_drop_level_feats.py
@@ -1,4 +1,6 @@
-from idmodels.preprocess import _drop_level_feats
+import pandas as pd
+
+from idmodels.features import LevelFeatureFilter
 
 
 def test_drop_level_feats():
@@ -27,7 +29,7 @@ def test_drop_level_feats():
         "inc_trans_cs_taylor_d1_c1_w5t_sNone_lag1", "inc_trans_cs_taylor_d1_c1_w5t_sNone_lag2",
         "inc_trans_cs_rollmean_w2_lag1", "inc_trans_cs_rollmean_w2_lag2",
         "inc_trans_cs_rollmean_w4_lag1", "inc_trans_cs_rollmean_w4_lag2", "horizon"]
-    
-    # subset of in_feats expected to be returned by _drop_level_feats:
+
+    # subset of in_feats expected to be kept by LevelFeatureFilter:
     # I have manually removed any that measure local level of the surveillance signal in some way
     # these are 'inc_trans_cs', rolling means of that, and degree 0 coefficients ('c0')
@@ -50,8 +52,8 @@ def test_drop_level_feats():
         "inc_trans_cs_taylor_d1_c1_w3t_sNone_lag1", "inc_trans_cs_taylor_d1_c1_w3t_sNone_lag2",
         "inc_trans_cs_taylor_d1_c1_w5t_sNone_lag1", "inc_trans_cs_taylor_d1_c1_w5t_sNone_lag2",
         "horizon"]
-    
-    actual = _drop_level_feats(in_feats)
-    
+
+    _, actual = LevelFeatureFilter().apply(pd.DataFrame(), in_feats)
+
     assert len(actual) == len(expected)
     assert not set(actual) - set(expected)
diff --git a/tests/unit/test_directional_wave_features.py b/tests/unit/test_directional_wave_features.py
index 5879531..33a5101 100644
--- a/tests/unit/test_directional_wave_features.py
+++ b/tests/unit/test_directional_wave_features.py
@@ -4,72 +4,44 @@
 import pandas as pd
 import pytest
 
-from idmodels.preprocess import create_directional_wave_features
+from idmodels.features import DirectionalWaveFeature
 
 
 def create_test_dataframe():
     """Create a simple test dataframe with synthetic data."""
-    # Create 3 locations over 5 time points
     dates = pd.date_range("2024-01-01", periods=5, freq="W")
     locations = ["01", "06", "36"]  # Alabama, California, 
New York data = [] for loc in locations: for date in dates: - data.append({ - "location": loc, - "wk_end_date": date, - "inc_trans_cs": np.random.randn(), - "agg_level": "state", - "source": "nhsn" - }) + data.append({"location": loc, + "wk_end_date": date, + "inc_trans_cs": np.random.randn(), + "agg_level": "state", + "source": "nhsn"}) return pd.DataFrame(data) -def test_create_directional_wave_features_disabled(): - """Test that function returns empty list when disabled.""" - df = create_test_dataframe() - original_cols = set(df.columns) - - # Test with None config - df_result, feat_names = create_directional_wave_features(df, wave_config=None) - assert feat_names == [] - assert set(df_result.columns) == original_cols - - # Test with enabled=False - wave_config = {"enabled": False} - df_result, feat_names = create_directional_wave_features(df, wave_config) - assert feat_names == [] - assert set(df_result.columns) == original_cols - - def test_create_directional_wave_features_basic(): """Test basic directional wave feature generation.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N", "S"], - "temporal_lags": [1], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } - - df_result, feat_names = create_directional_wave_features(df, wave_config) - - # Should have base features + lag1 for each direction - expected_feats = [ - "inc_trans_cs_wave_N", - "inc_trans_cs_wave_S", - "inc_trans_cs_wave_N_lag1", - "inc_trans_cs_wave_S_lag1" - ] + df_result, feat_names = DirectionalWaveFeature(directions=["N", "S"], + temporal_lags=[1], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) + + expected_feats = ["inc_trans_cs_wave_N", + "inc_trans_cs_wave_S", + "inc_trans_cs_wave_N_lag1", + "inc_trans_cs_wave_S_lag1"] assert set(feat_names) == set(expected_feats) - # Check that features were added to dataframe for feat in feat_names: assert feat in df_result.columns @@ -78,22 +50,16 @@ def test_create_directional_wave_features_all_directions(): """Test with all 8 directions.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"], - "temporal_lags": [], # No lags for simplicity - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } + _, feat_names = DirectionalWaveFeature(directions=["N", "NE", "E", "SE", "S", "SW", "W", "NW"], + temporal_lags=[], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) - df_result, feat_names = create_directional_wave_features(df, wave_config) - - # Should have 8 base features (one per direction) assert len(feat_names) == 8 - expected_directions = ["N", "NE", "E", "SE", "S", "SW", "W", "NW"] - for direction in expected_directions: + for direction in ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]: assert f"inc_trans_cs_wave_{direction}" in feat_names @@ -101,18 +67,13 @@ def test_create_directional_wave_features_with_aggregate(): """Test with aggregate feature enabled.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N", "S"], - "temporal_lags": [], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": True - } - - df_result, feat_names = create_directional_wave_features(df, wave_config) + _, feat_names = DirectionalWaveFeature(directions=["N", "S"], + temporal_lags=[], + max_distance_km=5000, + include_velocity=False, + include_aggregate=True, + 
).apply(df, []) - # Should have N, S, and avg assert "inc_trans_cs_wave_N" in feat_names assert "inc_trans_cs_wave_S" in feat_names assert "inc_trans_cs_wave_avg" in feat_names @@ -122,23 +83,16 @@ def test_create_directional_wave_features_with_lags(): """Test with multiple temporal lags.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [1, 2], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } + _, feat_names = DirectionalWaveFeature(directions=["N"], + temporal_lags=[1, 2], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) - df_result, feat_names = create_directional_wave_features(df, wave_config) - - # Should have base + lag1 + lag2 - expected_feats = [ - "inc_trans_cs_wave_N", - "inc_trans_cs_wave_N_lag1", - "inc_trans_cs_wave_N_lag2" - ] + expected_feats = ["inc_trans_cs_wave_N", + "inc_trans_cs_wave_N_lag1", + "inc_trans_cs_wave_N_lag2"] assert set(feat_names) == set(expected_feats) @@ -147,24 +101,17 @@ def test_create_directional_wave_features_with_velocity(): """Test with velocity features enabled.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [1], - "max_distance_km": 5000, - "include_velocity": True, - "include_aggregate": False - } - - df_result, feat_names = create_directional_wave_features(df, wave_config) + df_result, feat_names = DirectionalWaveFeature(directions=["N"], + temporal_lags=[1], + max_distance_km=5000, + include_velocity=True, + include_aggregate=False, + ).apply(df, []) - # Should have base + lag1 + velocity assert "inc_trans_cs_wave_N" in feat_names assert "inc_trans_cs_wave_N_lag1" in feat_names assert "inc_trans_cs_wave_N_velocity" in feat_names - # Check that velocity is computed correctly (current - lag1) - # For locations with enough history for loc in df_result["location"].unique(): loc_data = df_result[df_result["location"] == loc].reset_index(drop=True) for i in range(1, len(loc_data)): @@ -181,26 +128,19 @@ def test_create_directional_wave_features_lag_semantics(): """Test that lag1 refers to t-1, lag2 to t-2.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [1, 2], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } + df_result, _ = DirectionalWaveFeature(directions=["N"], + temporal_lags=[1, 2], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) - df_result, feat_names = create_directional_wave_features(df, wave_config) - - # Check lag semantics for one location loc_data = df_result[df_result["location"] == "01"].sort_values("wk_end_date").reset_index(drop=True) - # At index i, lag1 should equal base value at index i-1 for i in range(1, len(loc_data)): - base_prev = loc_data.loc[i-1, "inc_trans_cs_wave_N"] + base_prev = loc_data.loc[i - 1, "inc_trans_cs_wave_N"] lag1_curr = loc_data.loc[i, "inc_trans_cs_wave_N_lag1"] - # If both exist and are not NaN, they should match if not pd.isna(base_prev) and not pd.isna(lag1_curr): assert abs(base_prev - lag1_curr) < 1e-6 @@ -210,18 +150,13 @@ def test_create_directional_wave_features_preserves_index(): df = create_test_dataframe() original_index = df.index.tolist() - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } - - df_result, _ = 
create_directional_wave_features(df, wave_config) + df_result, _ = DirectionalWaveFeature(directions=["N"], + temporal_lags=[], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) - # Index should be preserved assert df_result.index.tolist() == original_index @@ -229,86 +164,66 @@ def test_create_directional_wave_features_invalid_direction(): """Test that invalid directions raise ValueError.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N", "INVALID"], - "temporal_lags": [], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } - with pytest.raises(ValueError, match="Invalid direction"): - create_directional_wave_features(df, wave_config) + DirectionalWaveFeature(directions=["N", "INVALID"], + temporal_lags=[], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) def test_create_directional_wave_features_multiple_agg_levels(): """Test that multiple agg_levels raise ValueError.""" df = create_test_dataframe() - # Add a row with different agg_level - df.loc[len(df)] = { - "location": "01001", - "wk_end_date": pd.Timestamp("2024-01-01"), - "inc_trans_cs": 0.5, - "agg_level": "county", - "source": "nhsn" - } - - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } + df.loc[len(df)] = {"location": "01001", + "wk_end_date": pd.Timestamp("2024-01-01"), + "inc_trans_cs": 0.5, + "agg_level": "county", + "source": "nhsn"} with pytest.raises(ValueError, match="Multiple aggregation levels"): - create_directional_wave_features(df, wave_config) + DirectionalWaveFeature(directions=["N"], + temporal_lags=[], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) def test_create_directional_wave_features_missing_coordinates(): """Test that missing coordinates raise ValueError.""" df = create_test_dataframe() - # Add a location without coordinates - df.loc[len(df)] = { - "location": "FAKE99", - "wk_end_date": pd.Timestamp("2024-01-01"), - "inc_trans_cs": 0.5, - "agg_level": "state", - "source": "nhsn" - } - - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [], - "max_distance_km": 5000, - "include_velocity": False, - "include_aggregate": False - } + df.loc[len(df)] = {"location": "FAKE99", + "wk_end_date": pd.Timestamp("2024-01-01"), + "inc_trans_cs": 0.5, + "agg_level": "state", + "source": "nhsn"} with pytest.raises(ValueError, match="Missing coordinates"): - create_directional_wave_features(df, wave_config) + DirectionalWaveFeature(directions=["N"], + temporal_lags=[], + max_distance_km=5000, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) def test_create_directional_wave_features_default_config(): """Test that default configuration values work.""" df = create_test_dataframe() - # Minimal config - should use defaults - wave_config = { - "enabled": True - } - - df_result, feat_names = create_directional_wave_features(df, wave_config) + _, feat_names = DirectionalWaveFeature(directions=["N", "NE", "E", "SE", "S", "SW", "W", "NW"], + temporal_lags=[1, 2], + max_distance_km=1000, + include_velocity=False, + include_aggregate=True, + ).apply(df, []) - # Should use default 8 directions assert len([f for f in feat_names if "lag" not in f and "velocity" not in f]) == 9 # 8 directions + avg - # Should have lag1 and lag2 (default temporal_lags=[1, 2]) 
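+    # the explicit arguments above mirror the removed wave_config defaults:
+    # all 8 compass directions, temporal lags [1, 2], aggregate feature on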
assert any("lag1" in f for f in feat_names) assert any("lag2" in f for f in feat_names) @@ -317,29 +232,19 @@ def test_create_directional_wave_features_feature_count(): """Test that the correct number of features is generated.""" df = create_test_dataframe() - wave_config = { - "enabled": True, - "directions": ["N", "S", "E", "W"], # 4 directions - "temporal_lags": [1, 2], # 2 lags - "max_distance_km": 5000, - "include_velocity": True, # Add velocity - "include_aggregate": True # Add aggregate - } - - df_result, feat_names = create_directional_wave_features(df, wave_config) - - # Expected: - # - 4 base directional features - # - 1 aggregate feature - # - (4 + 1) * 2 lags = 10 lag features - # - 5 velocity features (4 directions + 1 aggregate) - # Total: 4 + 1 + 10 + 5 = 20 + _, feat_names = DirectionalWaveFeature(directions=["N", "S", "E", "W"], + temporal_lags=[1, 2], + max_distance_km=5000, + include_velocity=True, + include_aggregate=True, + ).apply(df, []) + + # 4 base directional + 1 aggregate + (4+1)*2 lags + 5 velocity = 20 assert len(feat_names) == 20 def test_create_directional_wave_features_no_neighbors(): """Test behavior when locations have no neighbors in a direction.""" - # Create single location df = pd.DataFrame([ { "location": "01", @@ -350,17 +255,12 @@ def test_create_directional_wave_features_no_neighbors(): } ]) - wave_config = { - "enabled": True, - "directions": ["N"], - "temporal_lags": [], - "max_distance_km": 10, # Very small distance - no neighbors - "include_velocity": False, - "include_aggregate": False - } - - df_result, feat_names = create_directional_wave_features(df, wave_config) + df_result, feat_names = DirectionalWaveFeature(directions=["N"], + temporal_lags=[], + max_distance_km=10, + include_velocity=False, + include_aggregate=False, + ).apply(df, []) - # Feature should exist but be NaN (no neighbors) assert "inc_trans_cs_wave_N" in feat_names assert pd.isna(df_result.loc[0, "inc_trans_cs_wave_N"]) diff --git a/tests/unit/test_features.py b/tests/unit/test_features.py new file mode 100644 index 0000000..e317a04 --- /dev/null +++ b/tests/unit/test_features.py @@ -0,0 +1,182 @@ +"""Unit tests for feature classes in idmodels.features.""" + +import numpy as np +import pandas as pd +import pytest + +from idmodels.features import ( + FeaturePipeline, + HorizonTargetFeature, + LagFeature, + LevelFeatureFilter, + OneHotEncodingFeature, +) + + +def make_df(n_weeks=20, locations=("01", "06", "36"), seed=0): + rng = np.random.default_rng(seed) + rows = [] + for loc in locations: + for week in range(1, n_weeks + 1): + rows.append({"source": "nhsn", + "location": loc, + "season": "2023-24", + "season_week": week, + "wk_end_date": pd.Timestamp("2023-10-01") + pd.Timedelta(weeks=week - 1), + "inc_trans_cs": rng.normal(0.0, 0.5), }) + return pd.DataFrame(rows) + + +class TestOneHotEncodingFeature: + def test_adds_dummy_columns(self): + df = make_df() + feat = OneHotEncodingFeature(columns=["source"]) + df_out, feat_names = feat.apply(df.copy(), []) + assert "source_nhsn" in feat_names + assert "source_nhsn" in df_out.columns + + + def test_updates_feat_names(self): + df = make_df() + initial = ["inc_trans_cs"] + feat = OneHotEncodingFeature(columns=["source"]) + _, feat_names = feat.apply(df.copy(), list(initial)) + assert "inc_trans_cs" in feat_names + assert "source_nhsn" in feat_names + + +class TestLagFeature: + def test_creates_lag_columns(self): + df = make_df() + feat = LagFeature(columns=["inc_trans_cs"], lags=[1, 2]) + df_out, feat_names = 
feat.apply(df.copy(), ["inc_trans_cs"]) + assert "inc_trans_cs_lag1" in df_out.columns + assert "inc_trans_cs_lag2" in df_out.columns + assert "inc_trans_cs_lag1" in feat_names + assert "inc_trans_cs_lag2" in feat_names + + + def test_lag_semantics(self): + df = make_df(n_weeks=10, locations=("01",)) + feat = LagFeature(columns=["inc_trans_cs"], lags=[1]) + df_out, _ = feat.apply(df.copy(), ["inc_trans_cs"]) + df_out = df_out.sort_values("wk_end_date").reset_index(drop=True) + for i in range(1, len(df_out)): + orig = df_out.loc[i - 1, "inc_trans_cs"] + lagged = df_out.loc[i, "inc_trans_cs_lag1"] + if not pd.isna(lagged): + assert abs(orig - lagged) < 1e-12 + + + def test_raises_if_columns_none(self): + feat = LagFeature(columns=None, lags=[1]) + with pytest.raises(ValueError, match="must be resolved"): + feat.apply(pd.DataFrame(), []) + + +class TestHorizonTargetFeature: + def test_adds_horizon_column(self): + df = make_df() + feat = HorizonTargetFeature(column="inc_trans_cs", max_horizon=3) + df_out, feat_names = feat.apply(df.copy(), ["inc_trans_cs"]) + assert "horizon" in df_out.columns + assert "horizon" in feat_names + + + def test_adds_delta_target(self): + df = make_df() + feat = HorizonTargetFeature(column="inc_trans_cs", max_horizon=2) + df_out, _ = feat.apply(df.copy(), ["inc_trans_cs"]) + assert "delta_target" in df_out.columns + + + def test_expands_rows(self): + n_rows = len(make_df()) + df = make_df() + feat = HorizonTargetFeature(column="inc_trans_cs", max_horizon=4) + df_out, _ = feat.apply(df.copy(), ["inc_trans_cs"]) + assert len(df_out) >= n_rows + + +class TestLevelFeatureFilter: + def test_removes_inc_trans_cs(self): + feat_names = ["inc_trans_cs", "season_week", "log_pop"] + feat = LevelFeatureFilter() + _, out = feat.apply(pd.DataFrame(), list(feat_names)) + assert "inc_trans_cs" not in out + assert "season_week" in out + + + def test_removes_rollmean(self): + feat_names = ["inc_trans_cs_rollmean_w2", "inc_trans_cs_rollmean_w4", "log_pop"] + feat = LevelFeatureFilter() + _, out = feat.apply(pd.DataFrame(), list(feat_names)) + assert "inc_trans_cs_rollmean_w2" not in out + assert "log_pop" in out + + + def test_removes_taylor_c0_not_c1(self): + feat_names = ["inc_trans_cs_taylor_d2_c0_w4t_sNone", + "inc_trans_cs_taylor_d2_c1_w4t_sNone", ] + feat = LevelFeatureFilter() + _, out = feat.apply(pd.DataFrame(), list(feat_names)) + assert "inc_trans_cs_taylor_d2_c0_w4t_sNone" not in out + assert "inc_trans_cs_taylor_d2_c1_w4t_sNone" in out + + +class TestFeaturePipeline: + def test_accumulator_resolved_for_null_lag(self): + """LagFeature(columns=None) should resolve to columns added since last LagFeature.""" + df = make_df() + pipeline = FeaturePipeline( + features=[ + OneHotEncodingFeature(columns=["source"]), + LagFeature(columns=None, lags=[1]), + ], + initial_feat_names=["inc_trans_cs"], + ) + df_out, feat_names = pipeline.apply(df.copy()) + assert "source_nhsn_lag1" in feat_names + + + def test_initial_feats_not_in_accumulator(self): + """Columns in initial_feat_names should not be lagged by LagFeature(columns=None).""" + df = make_df() + pipeline = FeaturePipeline( + features=[LagFeature(columns=None, lags=[1])], + initial_feat_names=["inc_trans_cs"], + ) + df_out, feat_names = pipeline.apply(df.copy()) + # No new columns added before LagFeature, so no lags created + assert "inc_trans_cs_lag1" not in feat_names + + + def test_accumulator_resets_after_lag(self): + """After a LagFeature step, accumulator resets; next LagFeature only picks up new cols.""" + df = make_df() 
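+        # Two lag steps: the first drains the accumulator (the source dummies),
+        # so the second should only see the season dummies added in between.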
+ pipeline = FeaturePipeline( + features=[OneHotEncodingFeature(columns=["source"]), + LagFeature(columns=None, lags=[1]), + OneHotEncodingFeature(columns=["season"]), + LagFeature(columns=None, lags=[1])], + initial_feat_names=["inc_trans_cs"], + ) + df_out, feat_names = pipeline.apply(df.copy()) + # First batch lagged: source_nhsn → source_nhsn_lag1 + assert "source_nhsn_lag1" in feat_names + # Second batch lagged: only season dummy (added after first lag step) + # source_nhsn should NOT be re-lagged + source_lag_count = sum(1 for f in feat_names if f.startswith("source_nhsn_lag")) + assert source_lag_count == 1 + + + def test_explicit_columns_lag(self): + """LagFeature with explicit columns= ignores accumulator.""" + df = make_df() + pipeline = FeaturePipeline( + features=[LagFeature(columns=["inc_trans_cs"], lags=[1, 2])], + initial_feat_names=["inc_trans_cs"], + ) + df_out, feat_names = pipeline.apply(df.copy()) + assert "inc_trans_cs_lag1" in feat_names + assert "inc_trans_cs_lag2" in feat_names diff --git a/tests/unit/test_transforms.py b/tests/unit/test_transforms.py new file mode 100644 index 0000000..7eb9f2f --- /dev/null +++ b/tests/unit/test_transforms.py @@ -0,0 +1,137 @@ +"""Unit tests for transform classes in idmodels.transforms.""" + +import numpy as np +import pandas as pd + +from idmodels.constants import POWER_TRANSFORM_OFFSET +from idmodels.transforms import ( + CenterScaleTransform, + ComposedTransform, + FourthRootTransform, + IdentityTransform, +) + + +def make_df(n_per_group=20, n_groups=2, seed=0): + rng = np.random.default_rng(seed) + rows = [] + for loc in range(n_groups): + for week in range(1, n_per_group + 1): + rows.append({"source": "nhsn", + "location": str(loc).zfill(2), + "season_week": week, + "inc": max(0.0, rng.normal(0.5, 0.2))}) + return pd.DataFrame(rows) + + +class TestFourthRootTransform: + def test_apply_adds_inc_trans(self): + df = make_df() + t = FourthRootTransform() + out = t.apply(df.copy()) + assert "inc_trans" in out.columns + + + def test_apply_values(self): + df = make_df() + t = FourthRootTransform() + out = t.apply(df.copy()) + expected = (df["inc"] + POWER_TRANSFORM_OFFSET) ** 0.25 + np.testing.assert_allclose(out["inc_trans"].values, expected.values) + + + def test_roundtrip(self): + df = make_df() + t = FourthRootTransform() + out = t.apply(df.copy()) + recovered = t.invert(out["inc_trans"].values, out) + np.testing.assert_allclose(recovered, df["inc"].values, atol=1e-10) + + + def test_invert_clips_at_zero(self): + t = FourthRootTransform() + result = t.invert(np.array([-5.0, 0.0]), pd.DataFrame()) + assert (result >= -POWER_TRANSFORM_OFFSET).all() + + + def test_additive_shift(self): + df = make_df() + shift = 0.316 + t = FourthRootTransform(additive_shift=shift) + out = t.apply(df.copy()) + expected = (df["inc"] + shift + POWER_TRANSFORM_OFFSET) ** 0.25 + np.testing.assert_allclose(out["inc_trans"].values, expected.values) + + +class TestIdentityTransform: + def test_apply_adds_inc_trans(self): + df = make_df() + t = IdentityTransform() + out = t.apply(df.copy()) + assert "inc_trans" in out.columns + + + def test_roundtrip(self): + df = make_df() + t = IdentityTransform() + out = t.apply(df.copy()) + recovered = t.invert(out["inc_trans"].values, out) + np.testing.assert_allclose(recovered, df["inc"].values, atol=1e-10) + + +class TestCenterScaleTransform: + def test_apply_adds_factor_columns(self): + df = make_df(n_per_group=52) + df["inc_trans"] = (df["inc"] + POWER_TRANSFORM_OFFSET) ** 0.25 + t = CenterScaleTransform() + out = 
t.apply(df.copy()) + for col in ("inc_trans_cs", "inc_trans_scale_factor", "inc_trans_center_factor"): + assert col in out.columns + + + def test_invert_roundtrip(self): + df = make_df(n_per_group=52) + df["inc_trans"] = (df["inc"] + POWER_TRANSFORM_OFFSET) ** 0.25 + t = CenterScaleTransform() + out = t.apply(df.copy()) + recovered = t.invert(out["inc_trans_cs"].values, out) + np.testing.assert_allclose(recovered, out["inc_trans"].values, atol=1e-10) + + + def test_invert_uses_offset(self): + """Verify invert uses (scale_factor + 0.01), not just scale_factor.""" + context = pd.DataFrame({ + "inc_trans_scale_factor": [1.0], + "inc_trans_center_factor": [0.0], + }) + t = CenterScaleTransform() + result = t.invert(np.array([1.0]), context) + assert abs(result[0] - 1.01) < 1e-12 + + + def test_only_in_season_rows_used_for_scale(self): + df = make_df(n_per_group=52) + df["inc_trans"] = (df["inc"] + POWER_TRANSFORM_OFFSET) ** 0.25 + t = CenterScaleTransform() + out = t.apply(df.copy()) + # Out-of-season rows should still have scale factors (broadcast from in-season) + assert out["inc_trans_scale_factor"].notna().all() + + +class TestComposedTransform: + def test_roundtrip_fourth_root_then_center_scale(self): + df = make_df(n_per_group=52) + t = ComposedTransform([FourthRootTransform(), CenterScaleTransform()]) + out = t.apply(df.copy()) + assert "inc_trans_cs" in out.columns + + # invert() reverses both transforms: CenterScale then FourthRoot → back to inc + recovered_inc = t.invert(out["inc_trans_cs"].values, out) + np.testing.assert_allclose(recovered_inc, df["inc"].values, atol=1e-10) + + + def test_apply_calls_transforms_in_order(self): + df = make_df() + t = ComposedTransform([FourthRootTransform()]) + out = t.apply(df.copy()) + assert "inc_trans" in out.columns diff --git a/uv.lock b/uv.lock index c028cce..7439aaf 100644 --- a/uv.lock +++ b/uv.lock @@ -614,7 +614,7 @@ requires-dist = [ { name = "pre-commit", marker = "extra == 'dev'" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "ruff", marker = "extra == 'dev'" }, - { name = "sarix", git = "https://github.com/reichlab/sarix" }, + { name = "sarix", git = "https://github.com/reichlab/sarix?rev=35eea2379a9790e0457b1aed41d13509e5d5056f" }, { name = "scikit-learn" }, { name = "timeseriesutils", git = "https://github.com/reichlab/timeseriesutils" }, { name = "tqdm" }, @@ -1635,7 +1635,7 @@ wheels = [ [[package]] name = "sarix" version = "0.2.0" -source = { git = "https://github.com/reichlab/sarix#35eea2379a9790e0457b1aed41d13509e5d5056f" } +source = { git = "https://github.com/reichlab/sarix?rev=35eea2379a9790e0457b1aed41d13509e5d5056f#35eea2379a9790e0457b1aed41d13509e5d5056f" } dependencies = [ { name = "jax" }, { name = "matplotlib" },
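
Reviewer note: the new tests in tests/unit/test_transforms.py pin down an apply/invert contract, but src/idmodels/transforms.py itself is not part of this diff. As a reading aid, here is a minimal sketch consistent with those tests; it assumes apply(df) reads an "inc" column and writes "inc_trans", and that invert(values, context) undoes the forward step. The real classes may differ in detail.

import numpy as np
import pandas as pd

POWER_TRANSFORM_OFFSET = 0.01  # mirrors src/idmodels/constants.py


class FourthRootTransform:
    """Fourth-root transform; the offset keeps zero counts invertible."""

    def __init__(self, additive_shift: float = 0.0):
        self.additive_shift = additive_shift

    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
        df["inc_trans"] = (df["inc"] + self.additive_shift + POWER_TRANSFORM_OFFSET) ** 0.25
        return df

    def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray:
        # Clip negative draws before raising to the fourth power, then remove
        # the offset and shift; results bottom out at -(offset + shift), which
        # is what test_invert_clips_at_zero checks in the unshifted case.
        return np.maximum(values, 0.0) ** 4 - POWER_TRANSFORM_OFFSET - self.additive_shift


class ComposedTransform:
    """Apply transforms left to right; invert them right to left."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
        for t in self.transforms:
            df = t.apply(df)
        return df

    def invert(self, values: np.ndarray, context: pd.DataFrame) -> np.ndarray:
        for t in reversed(self.transforms):
            values = t.invert(values, context)
        return values

Under this sketch, ComposedTransform([FourthRootTransform(), CenterScaleTransform()]).invert(...) first undoes the centering/scaling and then the fourth root, which is exactly the round trip asserted in test_roundtrip_fourth_root_then_center_scale.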