Skip to content

Commit 6d1acd1

Browse files
authored
Minor adjustments to the plots (#60)
* Minor adjustments * Control spacing between time ticks * Ignore numerical type casting errors.
1 parent 19868fc commit 6d1acd1

File tree

2 files changed

+88
-22
lines changed

2 files changed

+88
-22
lines changed

src/covvfit/_cli/infer.py

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Script running Covvfit inference on the data."""
2+
import warnings
23
from pathlib import Path
34
from typing import Annotated, NamedTuple, Optional
45

@@ -117,31 +118,45 @@ class PlotDimensions(pydantic.BaseModel):
117118
panel_height: float = 1.5
118119
dpi: int = 350
119120

120-
wspace: float = 1.0
121-
hspace: float = 0.5
121+
wspace: float = pydantic.Field(
122+
default=1.0, help="Horizontal (width) spacing between figure panels."
123+
)
124+
hspace: float = pydantic.Field(
125+
default=0.5, help="Vertical (height) spacing between figure panels."
126+
)
122127

123-
left: float = 1.0
124-
right: float = 1.5
125-
top: float = 0.7
126-
bottom: float = 0.5
128+
left: float = pydantic.Field(default=1.0, help="Left margin in the figure.")
129+
right: float = pydantic.Field(default=1.5, help="Right margin in the figure.")
130+
top: float = pydantic.Field(default=0.7, help="Top margin in the figure.")
131+
bottom: float = pydantic.Field(default=0.5, help="Bottom margin in the figure.")
127132

128133

129134
class PlotSettings(pydantic.BaseModel):
130135
dimensions: PlotDimensions = pydantic.Field(default_factory=PlotDimensions)
131136
prediction: PredictionRegion = pydantic.Field(default_factory=PredictionRegion)
132137
variant_colors: dict[str, str] = pydantic.Field(
133-
default_factory=lambda: plot_ts.COLORS_COVSPECTRUM
138+
default_factory=lambda: plot_ts.COLORS_COVSPECTRUM,
139+
help="Dictionary mapping variants to colors in the plot.",
140+
)
141+
time_spacing: pydantic.conint(ge=1) = pydantic.Field(
142+
default=1, help="Spacing between ticks on the time axis (in months)."
134143
)
135144

136145

137146
class Config(pydantic.BaseModel):
138-
variants: list[str] = pydantic.Field(default_factory=lambda: [])
139-
plot: PlotSettings = pydantic.Field(default_factory=PlotSettings)
147+
variants: list[str] = pydantic.Field(
148+
default_factory=lambda: [],
149+
help="List of variants to be included in the analysis.",
150+
)
151+
plot: PlotSettings = pydantic.Field(
152+
default_factory=PlotSettings, help="Plot settings."
153+
)
140154

141155

142156
def _parse_config(
143157
config_path: Optional[str],
144158
variants: Optional[list[str]],
159+
time_spacing: Optional[int],
145160
) -> Config:
146161
if config_path is None:
147162
config = Config()
@@ -153,6 +168,9 @@ def _parse_config(
153168
if variants is not None:
154169
config.variants = variants
155170

171+
if time_spacing is not None:
172+
config.plot.time_spacing = time_spacing
173+
156174
if len(config.variants) == 0:
157175
raise ValueError("No variants have been specified.")
158176

@@ -195,6 +213,13 @@ def infer(
195213
help="Number of future days for which abundance prediction should be generated",
196214
),
197215
] = 60,
216+
time_spacing: Annotated[
217+
Optional[int],
218+
typer.Option(
219+
"--time-spacing",
220+
help="Spacing between ticks on the time axis in months",
221+
),
222+
] = None,
198223
variant_col: Annotated[
199224
str,
200225
typer.Option(
@@ -222,16 +247,32 @@ def infer(
222247
Optional[str],
223248
typer.Option("--matplotlib-backend", help="Matplotlib backend to use"),
224249
] = None,
250+
overwrite_output: Annotated[
251+
bool,
252+
typer.Option(
253+
"--overwrite-output",
254+
help="Allows overwriting the output directory, if it already exists. Note: this may result in unintented loss of data.",
255+
),
256+
] = False,
225257
) -> None:
226258
"""Runs growth advantage inference."""
227259
_set_matplotlib_backend(matplotlib_backend)
228260

261+
# Ignore warnings with JAX converting arrays from 64-bit to 32-bit
262+
warnings.filterwarnings(
263+
"ignore",
264+
message=r"Explicitly requested dtype float64 requested in zeros.*",
265+
category=UserWarning,
266+
)
267+
229268
if var is None and config is None:
230269
raise ValueError(
231270
"The variant names are not specified. Use `--config` argument or `-v` to specify them."
232271
)
233272

234-
config: Config = _parse_config(config_path=config, variants=var)
273+
config: Config = _parse_config(
274+
config_path=config, variants=var, time_spacing=time_spacing
275+
)
235276

236277
variants_investigated = config.variants
237278

@@ -248,7 +289,7 @@ def infer(
248289
)
249290

250291
output = Path(output)
251-
output.mkdir(parents=True, exist_ok=False)
292+
output.mkdir(parents=True, exist_ok=overwrite_output)
252293

253294
def pprint(message):
254295
with open(output / "log.txt", "a") as file:
@@ -329,14 +370,27 @@ def pprint(message):
329370
theta_star, standard_errors_estimates, confidence_level=0.95
330371
)
331372

332-
pprint("\n\nRelative growth advantages:")
373+
pprint("\n\nRelative growth advantages (per day):")
374+
for variant, m, low, up in zip(
375+
variants_effective[1:],
376+
qm.get_relative_growths(theta_star, n_variants=n_variants_effective),
377+
qm.get_relative_growths(confints_estimates[0], n_variants=n_variants_effective),
378+
qm.get_relative_growths(confints_estimates[1], n_variants=n_variants_effective),
379+
):
380+
pprint(
381+
f" {variant}: {float(m)/ time_scaler.time_unit :.4f} ({float(low) / time_scaler.time_unit:.4f}{float(up) / time_scaler.time_unit :.4f})"
382+
)
383+
384+
pprint("\n\nRelative growth advantages (per week):")
333385
for variant, m, low, up in zip(
334386
variants_effective[1:],
335387
qm.get_relative_growths(theta_star, n_variants=n_variants_effective),
336388
qm.get_relative_growths(confints_estimates[0], n_variants=n_variants_effective),
337389
qm.get_relative_growths(confints_estimates[1], n_variants=n_variants_effective),
338390
):
339-
pprint(f" {variant}: {float(m):.2f} ({float(low):.2f}{float(up):.2f})")
391+
pprint(
392+
f" {variant}: {DAYS_IN_A_WEEK * float(m)/ time_scaler.time_unit :.4f} ({DAYS_IN_A_WEEK * float(low) / time_scaler.time_unit:.4f}{DAYS_IN_A_WEEK * float(up) / time_scaler.time_unit :.4f})"
393+
)
340394

341395
# Generate predictions
342396
ys_fitted_confint = qm.get_confidence_bands_logit(
@@ -364,7 +418,6 @@ def pprint(message):
364418
)
365419

366420
# Create a plot
367-
368421
colors = [config.plot.variant_colors[var] for var in variants_investigated]
369422

370423
plot_dimensions = config.plot.dimensions
@@ -378,6 +431,7 @@ def pprint(message):
378431
bottom=plot_dimensions.bottom,
379432
left=plot_dimensions.left,
380433
right=plot_dimensions.right,
434+
sharex=True,
381435
)
382436

383437
def plot_city(ax, i: int) -> None:
@@ -432,7 +486,9 @@ def remove_0th(arr):
432486
alpha=0.3,
433487
)
434488

435-
adjust_axis_fn = plot_ts.AdjustXAxisForTime(start_date)
489+
adjust_axis_fn = plot_ts.AdjustXAxisForTime(
490+
start_date, spacing_months=config.plot.time_spacing
491+
)
436492
adjust_axis_fn(ax)
437493

438494
tick_positions = [0, 0.5, 1]

src/covvfit/plotting/_timeseries.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
"BA.2.86": "#FF20E0",
3232
# TODO(Pawel, David): Use consistent colors with Covspectrum
3333
"JN.1": "#00e9ff", # improv
34-
"KP.2": "#D16666", # improv
35-
"KP.3": "#66A366", # improv
36-
"XEC": "#A366A3", # improv
34+
"KP.2": "#876566",
35+
"KP.3": "#331eee",
36+
"XEC": "#a2a626",
3737
"undetermined": "#969696",
3838
}
3939

@@ -72,7 +72,7 @@ class _MonthStartLocator(ticker.Locator):
7272
the first day of each month within the data's visible range.
7373
"""
7474

75-
def __init__(self, start_date: str, time_unit: str) -> None:
75+
def __init__(self, start_date: str, time_unit: str, spacing_months: int) -> None:
7676
"""
7777
7878
Args:
@@ -83,6 +83,9 @@ def __init__(self, start_date: str, time_unit: str) -> None:
8383
# Store the reference start_date as a Timestamp
8484
self.start_date = pd.to_datetime(start_date)
8585
self.time_unit = time_unit
86+
self.spacing_months = spacing_months
87+
if spacing_months <= 0:
88+
raise ValueError("Has to be at least 1.")
8689

8790
# See the todo item in the `__call__` method
8891
if time_unit != "D":
@@ -119,7 +122,7 @@ def __call__(self):
119122
ticks.append(offset_days)
120123
current += pd.offsets.MonthBegin(1)
121124

122-
return ticks
125+
return ticks[:: self.spacing_months]
123126

124127
def tick_values(self, vmin, vmax):
125128
# Matplotlib may call tick_values directly; just reuse __call__()
@@ -131,8 +134,9 @@ def __init__(
131134
self,
132135
time0: str,
133136
*,
134-
fmt="%b. '%y",
137+
fmt="%b '%y",
135138
time_unit: str = "D",
139+
spacing_months: int = 1,
136140
) -> None:
137141
"""Adjusts the X ticks, so that the ticks
138142
are placed at the first day of each month.
@@ -146,6 +150,7 @@ def __init__(
146150
self.start_date = time0
147151
self.fmt = fmt
148152
self.time_unit = time_unit
153+
self.spacing_months = spacing_months
149154

150155
def _num_to_date(self, num: pd.Series | Float[Array, " timepoints"]) -> pd.Series:
151156
"""Convert days number into a date format"""
@@ -154,7 +159,11 @@ def _num_to_date(self, num: pd.Series | Float[Array, " timepoints"]) -> pd.Serie
154159

155160
def __call__(self, ax: plt.Axes) -> None:
156161
ax.xaxis.set_major_locator(
157-
_MonthStartLocator(start_date=self.start_date, time_unit=self.time_unit)
162+
_MonthStartLocator(
163+
start_date=self.start_date,
164+
time_unit=self.time_unit,
165+
spacing_months=self.spacing_months,
166+
)
158167
)
159168
ax.xaxis.set_major_formatter(
160169
ticker.FuncFormatter(lambda x, pos: self._num_to_date(x))
@@ -308,4 +317,5 @@ def plot_confidence_bands(
308317
alpha=alpha,
309318
label=label,
310319
**kwargs,
320+
edgecolor=None,
311321
)

0 commit comments

Comments
 (0)