Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ classifiers = [
dynamic = ["version"]

dependencies = [
"iddata @ git+https://github.com/reichlab/iddata",
"iddata @ git+https://github.com/reichlab/iddata@7b86ad0e513423faa8d327426c18350dfdfa07f0",
"lightgbm",
"numpy",
"pandas~=2.0", # pandas 3.0 breaks compatibility; remove cap once validated
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ frozenlist==1.5.0
# aiosignal
fsspec==2024.10.0
# via s3fs
iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8
iddata @ git+https://github.com/reichlab/iddata@7b86ad0e513423faa8d327426c18350dfdfa07f0
# via idmodels (pyproject.toml)
identify==2.6.1
# via pre-commit
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ frozenlist==1.5.0
# aiosignal
fsspec==2024.10.0
# via s3fs
iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8
iddata @ git+https://github.com/reichlab/iddata@7b86ad0e513423faa8d327426c18350dfdfa07f0
# via idmodels (pyproject.toml)
idna==3.10
# via yarl
Expand Down
26 changes: 26 additions & 0 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pickle
from pathlib import Path

from idmodels.sarix import SARIXModel


def main():
pkl_dir = Path("/Users/cornell/IdeaProjects/operational-models/covid_gbqr/")
# pkl_dir = Path('/Users/cornell/IdeaProjects/operational-models/covid_ar6_pooled/')
pkl_mc_rc_pairs = [
# ('2025-08-02-model_config.pkl', '2025-08-02-run_config.pkl'), # works -> 2025-08-02-UMass-gbqr.csv
("2025-08-09-model_config.pkl", "2025-08-09-run_config.pkl"), # fails -> 2025-08-09-UMass-gbqr.csv
]
for model_config_file_name, run_config_file_name in pkl_mc_rc_pairs:
with open(pkl_dir / model_config_file_name, "rb") as mc_fp, open(pkl_dir / run_config_file_name, "rb") as rc_fp:
model_config = pickle.load(mc_fp)
run_config = pickle.load(rc_fp)
print("*", model_config_file_name, ",", run_config_file_name)
print("yy", model_config)
model_config.x = []
model = SARIXModel(model_config)
model.run(run_config)


if __name__ == "__main__":
main()
31 changes: 21 additions & 10 deletions src/idmodels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from iddata.enums import Disease, SourceType

from idmodels.config import (
DataSource,
Disease,
GBQRModelConfig,
PoolingStrategy,
PowerTransform,
RunConfig,
SARIXFourierModelConfig,
SARIXModelConfig,
GBQRModelConfig,
PoolingStrategy,
PowerTransform,
RunConfig,
SARIXFourierModelConfig,
SARIXModelConfig,
)
from idmodels.gbqr import GBQRModel
from idmodels.sarix import SARIXFourierModel, SARIXModel

__all__ = ["DataSource", "Disease", "GBQRModel", "GBQRModelConfig", "PoolingStrategy", "PowerTransform", "RunConfig",
"SARIXFourierModel", "SARIXFourierModelConfig", "SARIXModel", "SARIXModelConfig"]
__all__ = [
"Disease",
"GBQRModel",
"GBQRModelConfig",
"PoolingStrategy",
"PowerTransform",
"RunConfig",
"SARIXFourierModel",
"SARIXFourierModelConfig",
"SARIXModel",
"SARIXModelConfig",
"SourceType",
]

__version__ = "1.3.1"
33 changes: 8 additions & 25 deletions src/idmodels/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,10 @@
from enum import Enum
from pathlib import Path


class DataSource(str, Enum):
NHSN = "nhsn"
NSSP = "nssp"
FLUSURVNET = "flusurvnet"
ILINET = "ilinet"


class Disease(str, Enum):
FLU = "flu"
COVID = "covid"
RSV = "rsv"
from iddata.enums import (
Disease, # used internally for RunConfig.disease; import from iddata.enums directly in callers
SourceType, # re-exported for callers: from idmodels.config import SourceType
)


class PowerTransform(str, Enum):
Expand All @@ -30,30 +22,22 @@ class PoolingStrategy(str, Enum):

@dataclass
class ModelConfig(ABC):
"""
Abstract base for model configuration.

Holds settings that describe *what* model to run and how it processes data (sources, transforms, pooling).
Not instantiated directly - use :class:`SARIXModelConfig` or :class:`GBQRModelConfig`.
"""
"""Abstract base for model configuration."""

model_name: str
sources: list[DataSource]
sources: list[SourceType]
fit_locations_separately: bool
power_transform: PowerTransform


def __post_init__(self):
if type(self) is ModelConfig:
raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig")


@dataclass
class RunConfig:
"""
Run configuration.

Holds settings that describe a single execution: which disease, which locations, output paths, quantile levels, etc.
"""
"""Run configuration: disease, locations, output paths, quantile levels."""

disease: Disease
ref_date: datetime.date
Expand Down Expand Up @@ -95,7 +79,6 @@ class GBQRModelConfig(ModelConfig):
reporting_adj: bool = False
save_feat_importance: bool = False

# directional wave features (disabled by default)
use_directional_waves: bool = False
wave_directions: list[str] = field(default_factory=lambda: ["N", "NE", "E", "SE", "S", "SW", "W", "NW"])
wave_temporal_lags: list[int] = field(default_factory=lambda: [1, 2])
Expand Down
6 changes: 6 additions & 0 deletions src/idmodels/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
NHSN_INCIDENCE_SHIFT: float = 0.75 ** 4 # ≈ 0.316

POWER_TRANSFORM_OFFSET: float = 0.01

IN_SEASON_WEEK_MIN: int = 10
IN_SEASON_WEEK_MAX: int = 45
Loading
Loading