-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path config.py
More file actions
87 lines (67 loc) · 2.22 KB
/
config.py
File metadata and controls
87 lines (67 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import datetime
from abc import ABC
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from iddata.enums import (
Disease, # used internally for RunConfig.disease; import from iddata.enums directly in callers
SourceType, # re-exported for callers: from idmodels.config import SourceType
)
class PowerTransform(str, Enum):
    """Power-transform option, identified by a short string code.

    Mixing in ``str`` makes every member compare equal to its code
    (``PowerTransform.NONE == "none"``), and ``PowerTransform("4rt")`` looks a
    member up by its code.

    NOTE: ``str(member)`` returns the qualified name (for example
    ``"PowerTransform.NONE"``), not the code. Use ``.value`` wherever the raw
    code string is needed.
    """

    # Member order is the iteration order; keep it stable.
    FOURTH_ROOT = "4rt"  # code "4rt" (a fourth-root transform, judging by the name)
    NONE = "none"  # code "none"
class PoolingStrategy(str, Enum):
    """Pooling option used by the ``*_pooling`` fields of the SARIX configs.

    Like ``PowerTransform``, members mix in ``str``, so they compare equal to
    their codes and can be looked up from them (``PoolingStrategy("shared")``).

    NOTE(review): the codes presumably select independent per-unit parameters
    ("none") versus a shared/pooled component ("shared"). Confirm against the
    model code.
    """

    # Member order is the iteration order; keep it stable.
    NONE = "none"
    SHARED = "shared"
@dataclass
class ModelConfig(ABC):
    """Abstract base for model configuration.

    Holds the fields shared by every model configuration. The class declares
    no abstract methods, so inheriting from ``ABC`` alone would not stop it
    from being instantiated. ``__post_init__`` therefore refuses construction
    of ``ModelConfig`` itself, while instances of any subclass pass through.
    """

    model_name: str
    sources: list[SourceType]
    fit_locations_separately: bool
    power_transform: PowerTransform

    def __post_init__(self):
        # Subclasses here do not override __post_init__, so this inherited
        # check runs for them too. It only rejects the exact base class.
        concrete_cls = type(self)
        if concrete_cls is not ModelConfig:
            return
        raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig")
@dataclass
class RunConfig:
    """Run configuration: disease, locations, output paths, quantile levels.

    Every field is required (none has a default). The declaration order below
    is the positional order of the generated ``__init__``, so do not reorder
    these fields.
    """
    disease: Disease
    ref_date: datetime.date
    output_root: Path
    artifact_store_root: Path | None  # optional: None is an accepted value
    max_horizon: int  # NOTE(review): horizon unit is not visible here -- TODO confirm
    states: list[str]
    hsas: list[str]
    q_levels: list[float]
    q_labels: list[str]  # presumably one label per entry of q_levels -- confirm; lengths are not checked here
@dataclass
class SARIXModelConfig(ModelConfig):
    """Configuration for the SARIX model.

    Adds defaulted hyperparameters after the required fields inherited from
    ModelConfig. Because all of them have defaults, they follow the inherited
    required fields in the generated ``__init__``. Keep the declaration order.
    """
    # Order / differencing settings. The names follow seasonal-ARIMA notation;
    # presumably p/d are non-seasonal and P/D seasonal with period
    # `season_period` -- confirm with the model code.
    p: int = 0
    P: int = 0
    d: int = 0
    D: int = 0
    season_period: int = 1
    # Pooling strategy for the theta and sigma parameters respectively.
    theta_pooling: PoolingStrategy = PoolingStrategy.NONE
    sigma_pooling: PoolingStrategy = PoolingStrategy.NONE
    # NOTE(review): element type of `x` is not visible here (presumably
    # covariate/feature names) -- TODO confirm and tighten the annotation.
    # default_factory gives each instance its own empty list.
    x: list = field(default_factory=list)
    # Presumably MCMC sampler settings (warmup draws, kept draws, chains) --
    # confirm with the fitting code.
    num_warmup: int = 2000
    num_samples: int = 2000
    num_chains: int = 1
@dataclass
class SARIXFourierModelConfig(SARIXModelConfig):
    """SARIX configuration with additional Fourier-term settings.

    Inherits every SARIXModelConfig field and appends the two defaulted
    fields below.
    """
    fourier_K: int = 1  # presumably the number of Fourier harmonics -- confirm
    fourier_pooling: PoolingStrategy = PoolingStrategy.NONE  # pooling for the Fourier terms
# Default wave-feature settings for GBQRModelConfig. They are stored as
# immutable tuples, and each config instance receives its own fresh list copy
# through default_factory.
_GBQR_DEFAULT_WAVE_DIRECTIONS = ("N", "NE", "E", "SE", "S", "SW", "W", "NW")
_GBQR_DEFAULT_WAVE_TEMPORAL_LAGS = (1, 2)


@dataclass
class GBQRModelConfig(ModelConfig):
    """Configuration for the GBQR model.

    Adds defaulted settings after the required fields inherited from
    ModelConfig. The declaration order is the positional order of the
    generated ``__init__``, so keep it.
    """

    # Feature / bagging settings. NOTE(review): `bag_frac_samples` is
    # presumably the fraction of samples drawn per bag; no range check is
    # done here.
    incl_level_feats: bool = True
    num_bags: int = 100
    bag_frac_samples: float = 0.7
    reporting_adj: bool = False
    save_feat_importance: bool = False

    # Directional-wave feature settings. Presumably the wave_* fields only
    # take effect when `use_directional_waves` is True -- confirm with the
    # feature code.
    use_directional_waves: bool = False
    wave_directions: list[str] = field(
        default_factory=lambda: list(_GBQR_DEFAULT_WAVE_DIRECTIONS)
    )
    wave_temporal_lags: list[int] = field(
        default_factory=lambda: list(_GBQR_DEFAULT_WAVE_TEMPORAL_LAGS)
    )
    wave_max_distance_km: float = 1000.0
    wave_include_velocity: bool = False
    wave_include_aggregate: bool = True