11"""Script running Covvfit inference on the data."""
2+ import warnings
23from pathlib import Path
34from typing import Annotated , NamedTuple , Optional
45
@@ -117,31 +118,45 @@ class PlotDimensions(pydantic.BaseModel):
117118 panel_height : float = 1.5
118119 dpi : int = 350
119120
120- wspace : float = 1.0
121- hspace : float = 0.5
121+ wspace : float = pydantic .Field (
122+ default = 1.0 , help = "Horizontal (width) spacing between figure panels."
123+ )
124+ hspace : float = pydantic .Field (
125+ default = 0.5 , help = "Vertical (height) spacing between figure panels."
126+ )
122127
123- left : float = 1.0
124- right : float = 1.5
125- top : float = 0.7
126- bottom : float = 0.5
128+ left : float = pydantic . Field ( default = 1.0 , help = "Left margin in the figure." )
129+ right : float = pydantic . Field ( default = 1.5 , help = "Right margin in the figure." )
130+ top : float = pydantic . Field ( default = 0.7 , help = "Top margin in the figure." )
131+ bottom : float = pydantic . Field ( default = 0.5 , help = "Bottom margin in the figure." )
127132
128133
129134class PlotSettings (pydantic .BaseModel ):
130135 dimensions : PlotDimensions = pydantic .Field (default_factory = PlotDimensions )
131136 prediction : PredictionRegion = pydantic .Field (default_factory = PredictionRegion )
132137 variant_colors : dict [str , str ] = pydantic .Field (
133- default_factory = lambda : plot_ts .COLORS_COVSPECTRUM
138+ default_factory = lambda : plot_ts .COLORS_COVSPECTRUM ,
139+ help = "Dictionary mapping variants to colors in the plot." ,
140+ )
141+ time_spacing : pydantic .conint (ge = 1 ) = pydantic .Field (
142+ default = 1 , help = "Spacing between ticks on the time axis (in months)."
134143 )
135144
136145
137146class Config (pydantic .BaseModel ):
138- variants : list [str ] = pydantic .Field (default_factory = lambda : [])
139- plot : PlotSettings = pydantic .Field (default_factory = PlotSettings )
147+ variants : list [str ] = pydantic .Field (
148+ default_factory = lambda : [],
149+ help = "List of variants to be included in the analysis." ,
150+ )
151+ plot : PlotSettings = pydantic .Field (
152+ default_factory = PlotSettings , help = "Plot settings."
153+ )
140154
141155
142156def _parse_config (
143157 config_path : Optional [str ],
144158 variants : Optional [list [str ]],
159+ time_spacing : Optional [int ],
145160) -> Config :
146161 if config_path is None :
147162 config = Config ()
@@ -153,6 +168,9 @@ def _parse_config(
153168 if variants is not None :
154169 config .variants = variants
155170
171+ if time_spacing is not None :
172+ config .plot .time_spacing = time_spacing
173+
156174 if len (config .variants ) == 0 :
157175 raise ValueError ("No variants have been specified." )
158176
@@ -195,6 +213,13 @@ def infer(
195213 help = "Number of future days for which abundance prediction should be generated" ,
196214 ),
197215 ] = 60 ,
216+ time_spacing : Annotated [
217+ Optional [int ],
218+ typer .Option (
219+ "--time-spacing" ,
220+ help = "Spacing between ticks on the time axis in months" ,
221+ ),
222+ ] = None ,
198223 variant_col : Annotated [
199224 str ,
200225 typer .Option (
@@ -222,16 +247,32 @@ def infer(
222247 Optional [str ],
223248 typer .Option ("--matplotlib-backend" , help = "Matplotlib backend to use" ),
224249 ] = None ,
250+ overwrite_output : Annotated [
251+ bool ,
252+ typer .Option (
253+ "--overwrite-output" ,
254+ help = "Allows overwriting the output directory, if it already exists. Note: this may result in unintented loss of data." ,
255+ ),
256+ ] = False ,
225257) -> None :
226258 """Runs growth advantage inference."""
227259 _set_matplotlib_backend (matplotlib_backend )
228260
261+ # Ignore warnings with JAX converting arrays from 64-bit to 32-bit
262+ warnings .filterwarnings (
263+ "ignore" ,
264+ message = r"Explicitly requested dtype float64 requested in zeros.*" ,
265+ category = UserWarning ,
266+ )
267+
229268 if var is None and config is None :
230269 raise ValueError (
231270 "The variant names are not specified. Use `--config` argument or `-v` to specify them."
232271 )
233272
234- config : Config = _parse_config (config_path = config , variants = var )
273+ config : Config = _parse_config (
274+ config_path = config , variants = var , time_spacing = time_spacing
275+ )
235276
236277 variants_investigated = config .variants
237278
@@ -248,7 +289,7 @@ def infer(
248289 )
249290
250291 output = Path (output )
251- output .mkdir (parents = True , exist_ok = False )
292+ output .mkdir (parents = True , exist_ok = overwrite_output )
252293
253294 def pprint (message ):
254295 with open (output / "log.txt" , "a" ) as file :
@@ -329,14 +370,27 @@ def pprint(message):
329370 theta_star , standard_errors_estimates , confidence_level = 0.95
330371 )
331372
332- pprint ("\n \n Relative growth advantages:" )
373+ pprint ("\n \n Relative growth advantages (per day):" )
374+ for variant , m , low , up in zip (
375+ variants_effective [1 :],
376+ qm .get_relative_growths (theta_star , n_variants = n_variants_effective ),
377+ qm .get_relative_growths (confints_estimates [0 ], n_variants = n_variants_effective ),
378+ qm .get_relative_growths (confints_estimates [1 ], n_variants = n_variants_effective ),
379+ ):
380+ pprint (
381+ f" { variant } : { float (m )/ time_scaler .time_unit :.4f} ({ float (low ) / time_scaler .time_unit :.4f} – { float (up ) / time_scaler .time_unit :.4f} )"
382+ )
383+
384+ pprint ("\n \n Relative growth advantages (per week):" )
333385 for variant , m , low , up in zip (
334386 variants_effective [1 :],
335387 qm .get_relative_growths (theta_star , n_variants = n_variants_effective ),
336388 qm .get_relative_growths (confints_estimates [0 ], n_variants = n_variants_effective ),
337389 qm .get_relative_growths (confints_estimates [1 ], n_variants = n_variants_effective ),
338390 ):
339- pprint (f" { variant } : { float (m ):.2f} ({ float (low ):.2f} – { float (up ):.2f} )" )
391+ pprint (
392+ f" { variant } : { DAYS_IN_A_WEEK * float (m )/ time_scaler .time_unit :.4f} ({ DAYS_IN_A_WEEK * float (low ) / time_scaler .time_unit :.4f} – { DAYS_IN_A_WEEK * float (up ) / time_scaler .time_unit :.4f} )"
393+ )
340394
341395 # Generate predictions
342396 ys_fitted_confint = qm .get_confidence_bands_logit (
@@ -364,7 +418,6 @@ def pprint(message):
364418 )
365419
366420 # Create a plot
367-
368421 colors = [config .plot .variant_colors [var ] for var in variants_investigated ]
369422
370423 plot_dimensions = config .plot .dimensions
@@ -378,6 +431,7 @@ def pprint(message):
378431 bottom = plot_dimensions .bottom ,
379432 left = plot_dimensions .left ,
380433 right = plot_dimensions .right ,
434+ sharex = True ,
381435 )
382436
383437 def plot_city (ax , i : int ) -> None :
@@ -432,7 +486,9 @@ def remove_0th(arr):
432486 alpha = 0.3 ,
433487 )
434488
435- adjust_axis_fn = plot_ts .AdjustXAxisForTime (start_date )
489+ adjust_axis_fn = plot_ts .AdjustXAxisForTime (
490+ start_date , spacing_months = config .plot .time_spacing
491+ )
436492 adjust_axis_fn (ax )
437493
438494 tick_positions = [0 , 0.5 , 1 ]
0 commit comments