From 710b4abdcfb5e0089c4b873842a849212f72beba Mon Sep 17 00:00:00 2001 From: marlinfiggins Date: Mon, 30 Sep 2024 11:40:35 -0700 Subject: [PATCH 1/2] Adding time-varying hierarchical model --- config/mlr-config.yaml | 1 + scripts/run-mlr-model.py | 29 ++++++++++++++++++++++------- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/config/mlr-config.yaml b/config/mlr-config.yaml index 3ded58b..666e2e0 100644 --- a/config/mlr-config.yaml +++ b/config/mlr-config.yaml @@ -15,6 +15,7 @@ model: generation_time: 4.8 pivot: "24A" hierarchical: true + time_varying: true inference: method: "NUTS" diff --git a/scripts/run-mlr-model.py b/scripts/run-mlr-model.py index 6facdfd..5419fef 100644 --- a/scripts/run-mlr-model.py +++ b/scripts/run-mlr-model.py @@ -86,7 +86,7 @@ def load_data(self, override_seq_path=None): return raw_seq, locations - def load_model(self, override_hier=None): + def load_model(self, override_hier=None, override_time_varying=None): model_cf = self.config["model"] # Processing generation time @@ -96,17 +96,25 @@ def load_model(self, override_hier=None): if override_hier is not None: hier = override_hier + time_varying = parse_with_default(model_cf, "time_varying", dflt=False) + if override_time_varying is not None: + time_varying = override_time_varying + print("hierarchical:", hier) + print("time varying:", time_varying) # Processing likelihoods if hier: ps = parse_pool_scale(model_cf) print("Hierarchical pool scale:", ps) - model = ef.HierMLR(tau=tau, pool_scale=ps) + if time_varying: + model = ef.HierMLRTimeVarying(tau=tau) # TODO: Add options + else: + model = ef.HierMLR(tau=tau, pool_scale=ps) else: model = ef.MultinomialLogisticRegression(tau=tau) model.forecast_L = forecast_L - return model, hier + return model, hier, time_varying def load_optim(self): infer_cf = self.config["inference"] @@ -276,9 +284,9 @@ def make_raw_freq_tidy(data, location): return {"metadata": metadata, "data": entries} -def export_results(multi_posterior, ps, path, data_name, hier): +def export_results(multi_posterior, ps, path, data_name, hier, time_varying): EXPORT_SITES = ["freq", "ga", "freq_forecast"] - EXPORT_DATED = [True, False, True] + EXPORT_DATED = [True, time_varying, True] EXPORT_FORECASTS = [False, False, True] EXPORT_ATTRS = ["pivot"] @@ -383,6 +391,13 @@ def get_group_samples(samples, sites, group): help="Whether to run the model as hierarchical. Overrides model.hierarchical in config. " + "Default is false if unspecified." ) + + parser.add_argument( + "--time-varying", action='store_true', default=False, + help="Whether to run the model as time-varaying." + + "Default is false if unspecified." + ) + args = parser.parse_args() # Load configuration, data, and create model @@ -396,7 +411,7 @@ def get_group_samples(samples, sites, group): if args.hier: override_hier = args.hier - mlr_model, hier = config.load_model(override_hier=override_hier) + mlr_model, hier, time_varying = config.load_model(override_hier=override_hier) print("Model created.") inference_method = config.load_optim() @@ -452,4 +467,4 @@ def get_group_samples(samples, sites, group): config.config["settings"], "ps", dflt=[0.5, 0.8, 0.95] ) data_name = args.data_name or config.config["data"]["name"] - export_results(multi_posterior, ps, export_path, data_name, hier) + export_results(multi_posterior, ps, export_path, data_name, hier, time_varying) From 8f01d0c8d7053159df136bc372ef4284bfa7f4ff Mon Sep 17 00:00:00 2001 From: marlinfiggins Date: Mon, 30 Sep 2024 11:43:43 -0700 Subject: [PATCH 2/2] Adding override option for time-varying --- scripts/run-mlr-model.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/scripts/run-mlr-model.py b/scripts/run-mlr-model.py index 5419fef..82ec9b0 100644 --- a/scripts/run-mlr-model.py +++ b/scripts/run-mlr-model.py @@ -2,13 +2,15 @@ # coding: utf-8 import argparse +import json +import os +from datetime import date + +import evofr as ef import numpy as np import pandas as pd -import os import yaml -import json -import evofr as ef -from datetime import date + def parse_with_default(cf, var, dflt): if var in cf: @@ -394,7 +396,7 @@ def get_group_samples(samples, sites, group): parser.add_argument( "--time-varying", action='store_true', default=False, - help="Whether to run the model as time-varaying." + help="Whether to run the model as time-varaying. Overrides model.time_varying in config." + "Default is false if unspecified." ) @@ -411,7 +413,11 @@ def get_group_samples(samples, sites, group): if args.hier: override_hier = args.hier - mlr_model, hier, time_varying = config.load_model(override_hier=override_hier) + override_time_varying = None + if args.time_varying: + override_time_varying = args.time_varying + + mlr_model, hier, time_varying = config.load_model(override_hier=override_hier, override_time_varying=override_time_varying) print("Model created.") inference_method = config.load_optim()