@@ -12,6 +12,10 @@ class SARIXModel():
1212 def __init__ (self , model_config ):
1313 self .model_config = model_config
1414
15+ def _get_extra_sarix_params (self , df ):
16+ """Return extra parameters to pass to SARIX constructor. Returns empty dict by default."""
17+ return {}
18+
1519 def run (self , run_config ):
1620 fdl = DiseaseDataLoader ()
1721 df = fdl .load_data (nhsn_kwargs = {"as_of" : run_config .ref_date , "disease" : run_config .disease },
@@ -30,11 +34,14 @@ def run(self, run_config):
3034 on = "season" ) \
3135 .assign (delta_xmas = lambda x : x ["season_week" ] - x ["xmas_week" ])
3236 df ["xmas_spike" ] = np .maximum (3 - np .abs (df ["delta_xmas" ]), 0 )
33-
37+
3438 xy_colnames = self .model_config .x + ["inc_trans_cs" ]
3539 df = df .query ("wk_end_date >= '2022-10-01'" ).interpolate ()
3640 batched_xy = df [xy_colnames ].values .reshape (len (df ["location" ].unique ()), - 1 , len (xy_colnames ))
37-
41+
42+ # Get any extra parameters for the SARIX constructor
43+ extra_params = self ._get_extra_sarix_params (df )
44+
3845 sarix_fit_all_locs_theta_pooled = sarix .SARIX (
3946 xy = batched_xy ,
4047 p = self .model_config .p ,
@@ -48,7 +55,8 @@ def run(self, run_config):
4855 forecast_horizon = run_config .max_horizon ,
4956 num_warmup = run_config .num_warmup ,
5057 num_samples = run_config .num_samples ,
51- num_chains = run_config .num_chains
58+ num_chains = run_config .num_chains ,
59+ ** extra_params
5260 )
5361
5462 pred_qs = _np_percentile (sarix_fit_all_locs_theta_pooled .predictions [..., :, :, 0 ],
@@ -93,9 +101,34 @@ def run(self, run_config):
93101 run_config = run_config ,
94102 model_config = self .model_config
95103 )
104+ # Ensure output_type_id is string to avoid pandas inferring it as float when reading
105+ preds_df ["output_type_id" ] = preds_df ["output_type_id" ].astype (str )
96106 preds_df .to_csv (save_path , index = False )
97107
98108
109+ class SARIXFourierModel (SARIXModel ):
110+ """
111+ SARIX model with Fourier seasonality terms.
112+
113+ Adds annual seasonal patterns using Fourier harmonics to the base SARIX model.
114+
115+ Required model_config parameters:
116+ - fourier_K: Number of Fourier harmonic pairs (int)
117+ - fourier_pooling: How to share Fourier coefficients across locations ('none' or 'shared')
118+ """
119+ def _get_extra_sarix_params (self , df ):
120+ """Return Fourier-specific parameters for SARIX constructor."""
121+ # Extract day-of-year from dates for Fourier features
122+ # Take the first location's dates (same for all locations after reshaping)
123+ day_of_year = df .groupby ("location" )["wk_end_date" ].apply (lambda x : x .dt .dayofyear .values ).iloc [0 ]
124+
125+ return {
126+ "day_of_year" : day_of_year ,
127+ "fourier_K" : self .model_config .fourier_K ,
128+ "fourier_pooling" : self .model_config .fourier_pooling
129+ }
130+
131+
99132def _np_percentile (predictions , q_levels , axis ):
100133 """
101134 Simple helper function to ease patching from unit tests.
0 commit comments