From e618f5d6deae48a1b2166e6e9bf950fd7143d66f Mon Sep 17 00:00:00 2001 From: pciturri Date: Tue, 29 Apr 2025 11:46:53 +0200 Subject: [PATCH 1/5] refac: changed forecast registry abstraction. (i) renamed to model registry for clarity (ii) Registries' main job is to provide global access to object keys (model forecasts, input catalog and arguments file), build the key structure (for instance, the path tree) and check if these objects already exist (iii) current (ModelFileRegistry) and future (ModelHDF5Registry, ModelSQLRegistry) concrete classes derive from the abstract ModelRegistry, which defines the interface. --- floatcsep/experiment.py | 2 +- floatcsep/infrastructure/registries.py | 311 ++++++++++++++-------- floatcsep/infrastructure/repositories.py | 20 +- floatcsep/model.py | 26 +- tests/integration/test_model_accessors.py | 10 +- tests/integration/test_model_docker.py | 10 +- tests/unit/test_experiment.py | 1 + tests/unit/test_model.py | 18 +- tests/unit/test_registry.py | 62 ++--- tests/unit/test_repositories.py | 12 +- 10 files changed, 278 insertions(+), 194 deletions(-) diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index da2930a..92216cc 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -175,7 +175,7 @@ def __getattr__(self, item: str) -> object: Override built-in method to return the experiment attributes by also using the command ``experiment.{attr}``. Adds also to the experiment scope the keys of :attr:`region_config` or :attr:`time_config`. These are: ``start_date``, ``end_date``, - ``timewindows``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``, + ``time_windows``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``, `mag_max``, ``mag_bin``, ``depth_min`` depth_max . 
""" diff --git a/floatcsep/infrastructure/registries.py b/floatcsep/infrastructure/registries.py index 66d2d52..db655de 100644 --- a/floatcsep/infrastructure/registries.py +++ b/floatcsep/infrastructure/registries.py @@ -14,78 +14,22 @@ log = logging.getLogger("floatLogger") -class FileRegistry(ABC): - - def __init__(self, workdir: str) -> None: - self.workdir = workdir - - @staticmethod - def _parse_arg(arg) -> Union[str, list[str]]: - if isinstance(arg, (list, tuple)): - return timewindow2str(arg) - elif isinstance(arg, str): - return arg - elif hasattr(arg, "name"): - return arg.name - elif hasattr(arg, "__name__"): - return arg.__name__ - else: - raise Exception("Arg is not found") - +class ModelRegistry(ABC): @abstractmethod - def as_dict(self) -> dict: + def get_input_catalog_key(self, tstring: str) -> str: pass @abstractmethod - def build_tree(self, *args, **kwargs) -> None: + def get_forecast_key(self, tstring: str) -> str: pass @abstractmethod - def get(self, *args: Sequence[str]) -> Any: + def get_args_key(self, tstring: str) -> str: pass - def abs(self, *paths: Sequence[str]) -> str: - _path = normpath(abspath(join(self.workdir, *paths))) - return _path - - def abs_dir(self, *paths: Sequence[str]) -> str: - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - return _dir - - def rel(self, *paths: Sequence[str]) -> str: - """Gets the relative path of a file, when it was defined relative to. - - the experiment working dir. - """ - _abspath = normpath(abspath(join(self.workdir, *paths))) - _relpath = relpath(_abspath, self.workdir) - return _relpath - - def rel_dir(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to. - - the experiment working dir. 
- """ - - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - - return relpath(_dir, self.workdir) - - def file_exists(self, *args: Sequence[str]): - file_abspath = self.get(*args) - return exists(file_abspath) - - -class ForecastRegistry(FileRegistry): - """ - The class has the responsibility of managing the keys (based on timewindow strings) and path - structure of the forecast pertaining to a model (i.e., forecasts from different - time-windows), keeping track of the forecast existence and path in the filesystem. - """ +class ModelFileRegistry(ModelRegistry): def __init__( self, workdir: str, @@ -104,55 +48,25 @@ def __init__( args_file (str): The path of the arguments file (only for TimeDependentModel). input_cat (str): : The path of the arguments file (only for TimeDependentModel). """ - super().__init__(workdir) + self.workdir = workdir self.path = path self.database = database self.args_file = args_file self.input_cat = input_cat self.forecasts = {} - self._fmt = fmt - def get(self, *args: Sequence[str]) -> str: - """ - Args: - *args: A sequence of keys (usually time-window strings) - - Returns: - The registry element (forecast, catalogs, etc.) from a sequence of key value - (usually time-window strings) - """ - - val = self.__dict__ - for i in args: - parsed_arg = self._parse_arg(i) - val = val[parsed_arg] - return self.abs(val) - - def get_forecast(self, *args: Sequence[str]) -> str: - """ - Gets the filepath of a forecast for a given sequence of keys (usually a timewindow - string). - - Args: - *args: A sequence of keys (usually time-window strings) - - Returns: - The forecast registry from a sequence of key values - """ - return self.get("forecasts", *args) - @property def dir(self) -> str: """ Returns: The directory containing the model source. 
""" - if os.path.isdir(self.get("path")): - return self.get("path") + if os.path.isdir(self.get_attr("path")): + return self.get_attr("path") else: - return os.path.dirname(self.get("path")) + return os.path.dirname(self.get_attr("path")) @property def fmt(self) -> str: @@ -169,25 +83,73 @@ def fmt(self) -> str: return ext else: return self._fmt - - def as_dict(self) -> dict: + @staticmethod + def _parse_arg(arg) -> Union[str, list[str]]: + if isinstance(arg, (list, tuple)): + return timewindow2str(arg) + elif isinstance(arg, str): + return arg + elif hasattr(arg, "name"): + return arg.name + elif hasattr(arg, "__name__"): + return arg.__name__ + else: + raise Exception("Arg is not found") + + def get_attr(self, *args: Sequence[str]) -> str: """ + Args: + *args: A sequence of keys (usually time-window strings) Returns: - Simple dictionary serialization of the instance with the core attributes + The registry element (forecast, catalogs, etc.) from a sequence of key value + (usually time-window strings) """ - return { - "workdir": self.workdir, - "path": self.path, - "database": self.database, - "args_file": self.args_file, - "input_cat": self.input_cat, - "forecasts": self.forecasts, - } + + val = self.__dict__ + for i in args: + parsed_arg = self._parse_arg(i) + val = val[parsed_arg] + return self.abs(val) + + def abs(self, *paths: Sequence[str]) -> str: + + _path = normpath(abspath(join(self.workdir, *paths))) + return _path + + def abs_dir(self, *paths: Sequence[str]) -> str: + _path = normpath(abspath(join(self.workdir, *paths))) + _dir = dirname(_path) + return _dir + + def rel(self, *paths: Sequence[str]) -> str: + """Gets the relative path of a file, when it was defined relative to. + + the experiment working dir. 
+ """ + + _abspath = normpath(abspath(join(self.workdir, *paths))) + _relpath = relpath(_abspath, self.workdir) + return _relpath + + def rel_dir(self, *paths: Sequence[str]) -> str: + """Gets the absolute path of a file, when it was defined relative to. + + the experiment working dir. + """ + + _path = normpath(abspath(join(self.workdir, *paths))) + _dir = dirname(_path) + + return relpath(_dir, self.workdir) + + def file_exists(self, *args: Sequence[str]): + file_abspath = self.get_attr(*args) + return exists(file_abspath) def forecast_exists(self, timewindow: Union[str, list]) -> Union[bool, Sequence[bool]]: """ - Checks if forecasts exist for a sequence of timewindows + Checks if forecasts exist for a sequence of time_windows Args: timewindow (str, list): A single or sequence of strings representing a time window @@ -200,9 +162,48 @@ def forecast_exists(self, timewindow: Union[str, list]) -> Union[bool, Sequence[ else: return [self.file_exists("forecasts", i) for i in timewindow] + def get_input_catalog_key(self, *args: Sequence[str]) -> str: + """ + Gets the filepath of the input catalog for a given sequence of keys (usually a timewindow + string). + + Args: + *args: A sequence of keys (usually time-window strings) + + Returns: + The input catalog registry key from a sequence of key values + """ + return self.get_attr("input_cat", *args) + + def get_forecast_key(self, *args: Sequence[str]) -> str: + """ + Gets the filepath of a forecast for a given sequence of keys (usually a timewindow + string). + + Args: + *args: A sequence of keys (usually time-window strings) + + Returns: + The forecast registry from a sequence of key values + """ + return self.get_attr("forecasts", *args) + + def get_args_key(self, *args: Sequence[str]) -> str: + """ + Gets the filepath of an arguments file for a given sequence of keys (usually a timewindow + string). 
+ + Args: + *args: A sequence of keys (usually time-window strings) + + Returns: + The argument file's key(s) from a sequence of key values + """ + return self.get_attr("args_file", *args) + def build_tree( self, - timewindows: Sequence[Sequence[datetime]] = None, + time_windows: Sequence[Sequence[datetime]] = None, model_class: str = "TimeIndependentModel", prefix: str = None, args_file: str = None, @@ -212,7 +213,7 @@ def build_tree( Creates the run directory, and reads the file structure inside. Args: - timewindows (list(str)): List of time windows or strings. + time_windows (list(str)): List of time windows or strings. model_class (str): Model's class name prefix (str): prefix of the model forecast filenames if TD args_file (str, bool): input arguments path of the model if TD @@ -221,7 +222,7 @@ def build_tree( """ - windows = timewindow2str(timewindows) + windows = timewindow2str(time_windows) if model_class == "TimeIndependentModel": fname = self.database if self.database else self.path @@ -246,6 +247,21 @@ def build_tree( win: join(dirtree["forecasts"], f"{prefix}_{win}.{self.fmt}") for win in windows } + def as_dict(self) -> dict: + """ + + Returns: + Simple dictionary serialization of the instance with the core attributes + """ + return { + "workdir": self.workdir, + "path": self.path, + "database": self.database, + "args_file": self.args_file, + "input_cat": self.input_cat, + "forecasts": self.forecasts, + } + def log_tree(self) -> None: """ Logs a grouped summary of the forecasts' dictionary. 
@@ -266,6 +282,73 @@ def log_tree(self) -> None: log.debug(f" Time Window: {timewindow}") + + +class FileRegistry(ABC): + + def __init__(self, workdir: str) -> None: + self.workdir = workdir + + @staticmethod + def _parse_arg(arg) -> Union[str, list[str]]: + if isinstance(arg, (list, tuple)): + return timewindow2str(arg) + elif isinstance(arg, str): + return arg + elif hasattr(arg, "name"): + return arg.name + elif hasattr(arg, "__name__"): + return arg.__name__ + else: + raise Exception("Arg is not found") + + @abstractmethod + def as_dict(self) -> dict: + pass + + @abstractmethod + def build_tree(self, *args, **kwargs) -> None: + pass + + @abstractmethod + def get(self, *args: Sequence[str]) -> Any: + pass + + def abs(self, *paths: Sequence[str]) -> str: + _path = normpath(abspath(join(self.workdir, *paths))) + return _path + + def abs_dir(self, *paths: Sequence[str]) -> str: + _path = normpath(abspath(join(self.workdir, *paths))) + _dir = dirname(_path) + return _dir + + def rel(self, *paths: Sequence[str]) -> str: + """Gets the relative path of a file, when it was defined relative to. + + the experiment working dir. + """ + + _abspath = normpath(abspath(join(self.workdir, *paths))) + _relpath = relpath(_abspath, self.workdir) + return _relpath + + def rel_dir(self, *paths: Sequence[str]) -> str: + """Gets the absolute path of a file, when it was defined relative to. + + the experiment working dir. + """ + + _path = normpath(abspath(join(self.workdir, *paths))) + _dir = dirname(_path) + + return relpath(_dir, self.workdir) + + def file_exists(self, *args: Sequence[str]): + file_abspath = self.get(*args) + return exists(file_abspath) + + class ExperimentRegistry(FileRegistry): """ The class has the responsibility of managing the keys (based on models, timewindow and @@ -308,7 +391,7 @@ def get_forecast_registry(self, model_name: str) -> None: model_name (str): The name of the model. Returns: - ForecastRegistry: The ForecastRegistry associated with the model. 
+ ModelRegistry: The ModelRegistry associated with the model. """ return self.forecast_registries.get(model_name) @@ -404,7 +487,7 @@ def as_dict(self) -> str: def build_tree( self, - timewindows: Sequence[Sequence[datetime]], + time_windows: Sequence[Sequence[datetime]], models: Sequence["Model"], tests: Sequence["Evaluation"], ) -> None: @@ -412,12 +495,12 @@ def build_tree( Creates the run directory and reads the file structure inside. Args: - timewindows: List of time windows, or representing string. + time_windows: List of time windows, or representing string. models: List of models or model names tests: List of tests or test names """ - windows = timewindow2str(timewindows) + windows = timewindow2str(time_windows) models = [i.name for i in models] tests = [i.name for i in tests] diff --git a/floatcsep/infrastructure/repositories.py b/floatcsep/infrastructure/repositories.py index c485d26..94781fc 100644 --- a/floatcsep/infrastructure/repositories.py +++ b/floatcsep/infrastructure/repositories.py @@ -13,7 +13,7 @@ from csep.utils.time_utils import decimal_year from floatcsep.utils.readers import ForecastParsers -from floatcsep.infrastructure.registries import ForecastRegistry, ExperimentRegistry +from floatcsep.infrastructure.registries import ExperimentRegistry, ModelRegistry from floatcsep.utils.helpers import str2timewindow, parse_csep_func from floatcsep.utils.helpers import timewindow2str @@ -27,7 +27,7 @@ class ForecastRepository(ABC): @abstractmethod - def __init__(self, registry: ForecastRegistry): + def __init__(self, registry: ModelRegistry): self.registry = registry self.lazy_load = False self.forecasts = {} @@ -61,7 +61,7 @@ def __eq__(self, other) -> bool: @classmethod def factory( - cls, registry: ForecastRegistry, model_class: str, forecast_type: str = None, **kwargs + cls, registry: ModelRegistry, model_class: str, forecast_type: str = None, **kwargs ) -> "ForecastRepository": """Factory method. 
Instantiate first on explicit option provided in the model configuration. Then, defaults to gridded forecast for TimeIndependentModel and catalog @@ -89,11 +89,11 @@ class CatalogForecastRepository(ForecastRepository): """ - def __init__(self, registry: ForecastRegistry, **kwargs): + def __init__(self, registry: ModelRegistry, **kwargs): """ Args: - registry (ForecastRegistry): The registry containing the keys/path to the forecasts + registry (ModelRegistry): The registry containing the keys/path to the forecasts given their time-windows. **kwargs: """ @@ -120,7 +120,7 @@ def load_forecast( return [self._load_single_forecast(t, region) for t in tstring] def _load_single_forecast(self, t: str, region=None): - fc_path = self.registry.get_forecast(t) + fc_path = self.registry.get_forecast_key(t) return csep.load_catalog_forecast( fc_path, region=region, apply_filters=True, filter_spatial=True ) @@ -136,11 +136,11 @@ class GriddedForecastRepository(ForecastRepository): avoid parsing files repeatedly (Skip for large files). """ - def __init__(self, registry: ForecastRegistry, **kwargs): + def __init__(self, registry: ModelRegistry, **kwargs): """ Args: - registry (ForecastRegistry): The registry containing the keys/path to the forecasts + registry (ModelRegistry): The registry containing the keys/path to the forecasts given their time-windows. 
**kwargs: """ @@ -189,7 +189,7 @@ def _load_single_forecast(self, tstring: str, fc_unit: float = 1, name_=""): time_horizon = decimal_year(end_date) - decimal_year(start_date) tstring_ = timewindow2str([start_date, end_date]) - f_path = self.registry.get_forecast(tstring_) + f_path = self.registry.get_forecast_key(tstring_) f_parser = getattr(ForecastParsers, self.registry.fmt) rates, region, mags = f_parser(f_path) @@ -504,4 +504,4 @@ def set_input_cat(self, tstring: str, model: "Model") -> None: """ start, end = str2timewindow(tstring) sub_cat = self.catalog.filter([f"origin_time < {start.timestamp() * 1000}"]) - sub_cat.write_ascii(filename=model.registry.get("input_cat")) + sub_cat.write_ascii(filename=model.registry.get_input_catalog_key()) diff --git a/floatcsep/model.py b/floatcsep/model.py index 0993a55..e223e49 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -12,7 +12,7 @@ from floatcsep.utils.accessors import from_zenodo, from_git from floatcsep.infrastructure.environments import EnvironmentFactory from floatcsep.utils.readers import ForecastParsers, HDF5Serializer -from floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.infrastructure.registries import ModelFileRegistry from floatcsep.infrastructure.repositories import ForecastRepository from floatcsep.utils.helpers import timewindow2str, str2timewindow, parse_nested_dicts @@ -116,7 +116,7 @@ def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> Non raise FileNotFoundError("Model has no path or identified") if not os.path.exists(self.registry.dir) or not os.path.exists( - self.registry.get("path") + self.registry.get_attr("path") ): raise FileNotFoundError( f"Directory '{self.registry.dir}' or file {self.registry}' do not exist. 
" @@ -210,7 +210,7 @@ def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, self.forecast_unit = forecast_unit self.store_db = store_db - self.registry = ForecastRegistry(kwargs.get("workdir", os.getcwd()), model_path) + self.registry = ModelFileRegistry(kwargs.get("workdir", os.getcwd()), model_path) # todo: Set factory for registry. self.repository = ForecastRepository.factory( self.registry, model_class=self.__class__.__name__, **kwargs ) @@ -231,7 +231,7 @@ def stage(self, timewindows: Sequence[Sequence[datetime]] = None) -> None: if self.store_db: self.init_db() - self.registry.build_tree(timewindows=timewindows, model_class=self.__class__.__name__) + self.registry.build_tree(time_windows=timewindows, model_class=self.__class__.__name__) def init_db(self, dbpath: str = "", force: bool = False) -> None: """ @@ -247,7 +247,7 @@ def init_db(self, dbpath: str = "", force: bool = False) -> None: """ parser = getattr(ForecastParsers, self.registry.fmt) - rates, region, mag = parser(self.registry.get("path")) + rates, region, mag = parser(self.registry.get_attr("path")) db_func = HDF5Serializer.grid2hdf5 if not dbpath: @@ -320,9 +320,9 @@ def __init__( self.func = func self.func_kwargs = func_kwargs or {} - self.registry = ForecastRegistry(workdir=kwargs.get("workdir", os.getcwd()), - path=model_path, - fmt=fmt) + self.registry = ModelFileRegistry(workdir=kwargs.get("workdir", os.getcwd()), + path=model_path, + fmt=fmt) # todo: Set Factory for Registry self.repository = ForecastRepository.factory( self.registry, model_class=self.__class__.__name__, **kwargs ) @@ -351,7 +351,7 @@ def stage(self, timewindows=None) -> None: self.environment.create_environment(force=self.force_build) self.registry.build_tree( - timewindows=timewindows, + time_windows=timewindows, model_class=self.__class__.__name__, prefix=self.__dict__.get("prefix", self.name), args_file=self.__dict__.get("args_file", None), @@ -366,7 +366,7 @@ def get_forecast( Note: The 
argument ``tstring`` is formatted according to how the Experiment - handles timewindows, specified in the functions + handles time_windows, specified in the functions :func:`~floatcsep.utils.helpers.timewindow2str` and :func:`~floatcsep.utils.helpers.str2timewindow` @@ -385,7 +385,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None: Note: The argument ``tstring`` is formatted according to how the Experiment - handles timewindows, specified in the functions + handles time_windows, specified in the functions :func:`~floatcsep.utils.helpers.timewindow2str` and :func:`~floatcsep.utils.helpers.str2timewindow` @@ -406,7 +406,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None: f"Running {self.name} using {self.environment.__class__.__name__}:" f" {timewindow2str([start_date, end_date])}" ) - self.environment.run_command(f"{self.func} {self.registry.get('args_file')}") + self.environment.run_command(f"{self.func} {self.registry.get_args_key()}") def prepare_args(self, start: datetime, end: datetime, **kwargs) -> None: """ @@ -421,7 +421,7 @@ def prepare_args(self, start: datetime, end: datetime, **kwargs) -> None: **kwargs: represents additional model arguments (name/value pair) """ - filepath = self.registry.get("args_file") + filepath = self.registry.get_args_key() fmt = os.path.splitext(filepath)[1] if fmt == ".txt": diff --git a/tests/integration/test_model_accessors.py b/tests/integration/test_model_accessors.py index 5c940f9..bf27695 100644 --- a/tests/integration/test_model_accessors.py +++ b/tests/integration/test_model_accessors.py @@ -113,7 +113,7 @@ def init_model(name, model_path, **kwargs): return model @patch.object(EnvironmentManager, "create_environment") - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_from_git(self, mock_build_tree, mock_create_environment): """clones model from git, checks with test artifacts""" 
mock_build_tree.return_value = None @@ -169,7 +169,7 @@ def init_model(name, model_path, **kwargs): model = TimeIndependentModel(name=name, model_path=model_path, **kwargs) return model - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_zenodo(self, mock_buildtree): """downloads model from zenodo, checks with test artifacts""" mock_buildtree.return_value = None @@ -194,11 +194,11 @@ def test_zenodo(self, mock_buildtree): model_b.stage() self.assertEqual( - os.path.basename(model_a.registry.get("path")), - os.path.basename(model_b.registry.get("path")), + os.path.basename(model_a.registry.get_attr("path")), + os.path.basename(model_b.registry.get_attr("path")), ) self.assertEqual(model_a.name, model_b.name) - self.assertTrue(filecmp.cmp(model_a.registry.get("path"), model_b.registry.get("path"))) + self.assertTrue(filecmp.cmp(model_a.registry.get_attr("path"), model_b.registry.get_attr("path"))) def test_zenodo_fail(self): name = "mock_zenodo" diff --git a/tests/integration/test_model_docker.py b/tests/integration/test_model_docker.py index beefcb9..4efe165 100644 --- a/tests/integration/test_model_docker.py +++ b/tests/integration/test_model_docker.py @@ -78,20 +78,20 @@ def _make_model(self, subfolder: str, tag: str): workdir=str(model_dir), ) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_valid_model(self, mock_registry): model = self._make_model("valid", "testdocker_valid") model.stage() model.environment.run_command() # Should succeed with no exceptions - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_invalid_image_build_fails(self, mock_registry): model = self._make_model("invalid_image", "testdocker_invalid_image") with 
self.assertRaises(RuntimeError) as err: model.environment.create_environment(force=True) self.assertIn("Docker build error", str(err.exception)) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_invalid_entrypoint_fails_to_run(self, mock_registry): model = self._make_model("invalid_entrypoint", "testdocker_invalid_entrypoint") model.stage() @@ -99,7 +99,7 @@ def test_invalid_entrypoint_fails_to_run(self, mock_registry): model.environment.run_command() self.assertIn("exited with code", str(err.exception)) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_invalid_permission_fails_to_run(self, mock_registry): model = self._make_model("invalid_permission", "testdocker_invalid_permission") model.stage() @@ -107,7 +107,7 @@ def test_invalid_permission_fails_to_run(self, mock_registry): model.environment.run_command() self.assertIn("exited with code", str(err.exception)) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_valid_custom_uid_gid(self, mock_registry): # todo: look into it model = self._make_model("valid_custom_uid-gid", "testdocker_uid_gid") diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py index cce5cdc..a739a41 100644 --- a/tests/unit/test_experiment.py +++ b/tests/unit/test_experiment.py @@ -41,6 +41,7 @@ def assertEqualExperiment(self, exp_a, exp_b): self.assertEqual(exp_a.registry.workdir, os.getcwd()) self.assertEqual(exp_a.registry.workdir, exp_b.registry.workdir) self.assertEqual(exp_a.start_date, exp_b.start_date) + print(exp_a.timewindows, exp_b.timewindows) self.assertEqual(exp_a.timewindows, exp_b.timewindows) self.assertEqual(exp_a.exp_class, exp_b.exp_class) self.assertEqual(exp_a.region, 
exp_b.region) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 35bd2a4..66a5c3d 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -2,7 +2,7 @@ from unittest import TestCase from floatcsep.model import TimeIndependentModel -from floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.infrastructure.registries import ModelRegistry from floatcsep.infrastructure.repositories import GriddedForecastRepository from unittest.mock import patch, MagicMock, mock_open from floatcsep.model import TimeDependentModel @@ -27,7 +27,7 @@ def assertEqualModel(model_a, model_b): raise AssertionError("Models are not equal") for i in keys_a: - if isinstance(getattr(model_a, i), ForecastRegistry): + if isinstance(getattr(model_a, i), ModelRegistry): continue if not (getattr(model_a, i) == getattr(model_b, i)): print(getattr(model_a, i), getattr(model_b, i)) @@ -59,7 +59,7 @@ def test_from_filesystem(self): @patch("os.makedirs") @patch("floatcsep.model.TimeIndependentModel.get_source") - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_stage_creates_directory(self, mock_build_tree, mock_get_source, mock_makedirs): """Test stage method creates directory.""" model = self.init_model("mock", "mockfile.csv") @@ -159,7 +159,7 @@ class TestTimeDependentModel(TestModel): def setUp(self): # Patches - self.patcher_registry = patch("floatcsep.model.ForecastRegistry") + self.patcher_registry = patch("floatcsep.model.ModelFileRegistry") self.patcher_repository = patch("floatcsep.model.ForecastRepository.factory") self.patcher_environment = patch("floatcsep.model.EnvironmentFactory.get_env") self.patcher_get_source = patch( @@ -185,7 +185,7 @@ def setUp(self): # Set attributes on the mock objects self.mock_registry_instance.workdir = "/path/to/workdir" self.mock_registry_instance.path = "/path/to/model" - 
self.mock_registry_instance.get.return_value = ( + self.mock_registry_instance.get_args_key.return_value = ( "/path/to/args_file.txt" # Mocking the return of the registry call ) @@ -230,7 +230,7 @@ def test_stage(self, mk): self.model.zenodo_id, self.model.giturl, branch=self.model.repo_hash ) self.mock_registry_instance.build_tree.assert_called_once_with( - timewindows=["2020-01-01_2020-12-31"], + time_windows=["2020-01-01_2020-12-31"], model_class="TimeDependentModel", prefix=self.model.__dict__.get("prefix", self.name), args_file=self.model.__dict__.get("args_file", None), @@ -254,7 +254,7 @@ def test_create_forecast(self, prep_args_mock): self.model.create_forecast(tstring, force=True) self.mock_environment_instance.run_command.assert_called_once_with( - f'{self.func} {self.model.registry.get("args_file")}' + f'{self.func} {self.model.registry.get_args_key()}' ) @patch("builtins.open", new_callable=mock_open) @@ -279,7 +279,7 @@ def test_prepare_args(self, mock_json_dump, mock_json_load, mock_open_file): ] # Call the method - args_file_path = self.model.registry.get("args_file") + args_file_path = self.model.registry.get_args_key() self.model.prepare_args(start_date, end_date, custom_arg="value") mock_open_file.assert_any_call(args_file_path, "r") mock_open_file.assert_any_call(args_file_path, "w") @@ -293,7 +293,7 @@ def test_prepare_args(self, mock_json_dump, mock_json_load, mock_open_file): ) json_file_path = "/path/to/args_file.json" - self.model.registry.get.return_value = json_file_path + self.model.registry.get_args_key.return_value = json_file_path self.model.prepare_args(start_date, end_date, custom_arg="value") mock_open_file.assert_any_call(json_file_path, "r") diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index 369081f..5529a05 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -1,45 +1,45 @@ import unittest from datetime import datetime from unittest.mock import patch, MagicMock -from 
floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.infrastructure.registries import ModelFileRegistry -class TestForecastRegistry(unittest.TestCase): +class TestModelFileRegistry(unittest.TestCase): def setUp(self): - self.registry_file = ForecastRegistry( + self.registry_for_filebased_model = ModelFileRegistry( workdir="/test/workdir", path="/test/workdir/model.txt" ) - self.registry_folder = ForecastRegistry( + self.registry_for_folderbased_model = ModelFileRegistry( workdir="/test/workdir", path="/test/workdir/model" ) def test_call(self): - self.registry_file._parse_arg = MagicMock(return_value="path") - result = self.registry_file.get("path") + self.registry_for_filebased_model._parse_arg = MagicMock(return_value="path") + result = self.registry_for_filebased_model.get_attr("path") self.assertEqual(result, "/test/workdir/model.txt") @patch("os.path.isdir") def test_dir(self, mock_isdir): mock_isdir.return_value = False - self.assertEqual(self.registry_file.dir, "/test/workdir") + self.assertEqual(self.registry_for_filebased_model.dir, "/test/workdir") mock_isdir.return_value = True - self.assertEqual(self.registry_folder.dir, "/test/workdir/model") + self.assertEqual(self.registry_for_folderbased_model.dir, "/test/workdir/model") def test_fmt(self): - self.registry_file.database = "test.db" - self.assertEqual(self.registry_file.fmt, "db") - self.registry_file.database = None - self.assertEqual(self.registry_file.fmt, "txt") + self.registry_for_filebased_model.database = "test.db" + self.assertEqual(self.registry_for_filebased_model.fmt, "db") + self.registry_for_filebased_model.database = None + self.assertEqual(self.registry_for_filebased_model.fmt, "txt") def test_parse_arg(self): - self.assertEqual(self.registry_file._parse_arg("arg"), "arg") - self.assertRaises(Exception, self.registry_file._parse_arg, 123) + self.assertEqual(self.registry_for_filebased_model._parse_arg("arg"), "arg") + self.assertRaises(Exception, 
self.registry_for_filebased_model._parse_arg, 123) def test_as_dict(self): self.assertEqual( - self.registry_file.as_dict(), + self.registry_for_filebased_model.as_dict(), { "args_file": None, "database": None, @@ -51,28 +51,28 @@ def test_as_dict(self): ) def test_abs(self): - result = self.registry_file.abs("file.txt") + result = self.registry_for_filebased_model.abs("file.txt") self.assertTrue(result.endswith("/test/workdir/file.txt")) - def test_absdir(self): - result = self.registry_file.abs_dir("model.txt") + def test_abs_dir(self): + result = self.registry_for_filebased_model.abs_dir("model.txt") self.assertTrue(result.endswith("/test/workdir")) @patch("floatcsep.infrastructure.registries.exists") - def test_fileexists(self, mock_exists): + def test_file_exists(self, mock_exists): mock_exists.return_value = True - self.registry_file.get = MagicMock(return_value="/test/path/file.txt") - self.assertTrue(self.registry_file.file_exists("file.txt")) + self.registry_for_filebased_model.get_attr = MagicMock(return_value="/test/path/file.txt") + self.assertTrue(self.registry_for_filebased_model.file_exists("file.txt")) @patch("os.makedirs") @patch("os.listdir") def test_build_tree_time_independent(self, mock_listdir, mock_makedirs): timewindows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] - self.registry_file.build_tree( - timewindows=timewindows, model_class="TimeIndependentModel" + self.registry_for_filebased_model.build_tree( + time_windows=timewindows, model_class="TimeIndependentModel" ) - self.assertIn("2023-01-01_2023-01-02", self.registry_file.forecasts) - # self.assertIn("2023-01-01_2023-01-02", self.registry_file.inventory) + self.assertIn("2023-01-01_2023-01-02", self.registry_for_filebased_model.forecasts) + # self.assertIn("2023-01-01_2023-01-02", self.registry_for_filebased_model.inventory) @patch("os.makedirs") @patch("os.listdir") @@ -82,13 +82,13 @@ def test_build_tree_time_dependent(self, mock_listdir, mock_makedirs): [datetime(2023, 1, 1), 
datetime(2023, 1, 2)], [datetime(2023, 1, 2), datetime(2023, 1, 3)], ] - self.registry_folder.build_tree( - timewindows=timewindows, model_class="TimeDependentModel", prefix="forecast" + self.registry_for_folderbased_model.build_tree( + time_windows=timewindows, model_class="TimeDependentModel", prefix="forecast" ) - self.assertIn("2023-01-01_2023-01-02", self.registry_folder.forecasts) - # self.assertTrue(self.registry_folder.inventory["2023-01-01_2023-01-02"]) - self.assertIn("2023-01-02_2023-01-03", self.registry_folder.forecasts) - # self.assertTrue(self.registry_folder.inventory["2023-01-02_2023-01-03"]) + self.assertIn("2023-01-01_2023-01-02", self.registry_for_folderbased_model.forecasts) + # self.assertTrue(self.registry_for_folderbased_model.inventory["2023-01-01_2023-01-02"]) + self.assertIn("2023-01-02_2023-01-03", self.registry_for_folderbased_model.forecasts) + # self.assertTrue(self.registry_for_folderbased_model.inventory["2023-01-02_2023-01-03"]) if __name__ == "__main__": diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py index 1cab55a..c0f567b 100644 --- a/tests/unit/test_repositories.py +++ b/tests/unit/test_repositories.py @@ -5,7 +5,7 @@ from csep.core.forecasts import GriddedForecast from floatcsep.utils.readers import ForecastParsers -from floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.infrastructure.registries import ModelFileRegistry from floatcsep.infrastructure.repositories import ( CatalogForecastRepository, GriddedForecastRepository, @@ -17,7 +17,7 @@ class TestCatalogForecastRepository(unittest.TestCase): def setUp(self): - self.registry = MagicMock(spec=ForecastRegistry) + self.registry = MagicMock(spec=ModelFileRegistry) #todo: Factory registry self.registry.__call__ = MagicMock(return_value="a_duck") @patch("csep.load_catalog_forecast") @@ -48,7 +48,7 @@ def test_load_single_forecast(self, mock_load_catalog_forecast): class TestGriddedForecastRepository(unittest.TestCase): def 
setUp(self): - self.registry = MagicMock(spec=ForecastRegistry) + self.registry = MagicMock(spec=ModelFileRegistry) #todo: Factory registry self.registry.fmt = "hdf5" self.registry.__call__ = MagicMock(return_value="a_duck") @@ -138,10 +138,10 @@ def test_lazy_load_behavior(self, mock_parser): self.assertEqual(forecast, "forecatto") self.assertNotIn("2023-01-02_2023-01-03", repo.forecasts) - @patch("floatcsep.infrastructure.registries.ForecastRegistry") - def test_equal(self, MockForecastRegistry): + @patch("floatcsep.infrastructure.registries.ModelFileRegistry") + def test_equal(self, MockModelFileRegistry): - self.registry = MockForecastRegistry() + self.registry = MockModelFileRegistry() self.repo1 = CatalogForecastRepository(self.registry) self.repo2 = CatalogForecastRepository(self.registry) From 39094ee1ea242a7de85dcb8ec8c166279f93fdf0 Mon Sep 17 00:00:00 2001 From: pciturri Date: Tue, 29 Apr 2025 11:59:21 +0200 Subject: [PATCH 2/5] refac: renamed timewindow to time_window across the package. 
renamed qa tests to e2e --- floatcsep/experiment.py | 26 ++++++------- floatcsep/infrastructure/registries.py | 4 +- floatcsep/infrastructure/repositories.py | 4 +- floatcsep/model.py | 12 +++--- floatcsep/postprocess/plot_handler.py | 10 ++--- floatcsep/postprocess/reporting.py | 2 +- floatcsep/utils/helpers.py | 16 ++++---- tests/{qa => e2e}/test_data.py | 0 .../integration/test_model_infrastructure.py | 16 ++++---- tests/unit/test_experiment.py | 4 +- tests/unit/test_model.py | 2 +- tests/unit/test_plot_handler.py | 2 +- tests/unit/test_registry.py | 8 ++-- tests/unit/test_utils.py | 38 +++++++++---------- 14 files changed, 72 insertions(+), 72 deletions(-) rename tests/{qa => e2e}/test_data.py (100%) diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index 92216cc..cb36d70 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -52,8 +52,8 @@ class Experiment: - growth (:class:`str`): `incremental` or `cumulative` - offset (:class:`float`): recurrence of forecast creation. - For further details, see :func:`~floatcsep.utils.timewindows_ti` - and :func:`~floatcsep.utils.timewindows_td` + For further details, see :func:`~floatcsep.utils.time_windows_ti` + and :func:`~floatcsep.utils.time_windows_td` region_config (dict): Contains all the spatial and magnitude specifications. 
It must contain the following keys: @@ -143,7 +143,7 @@ def __init__( log.info(f"Setting up experiment {self.name}:") log.info(f"\tStart: {self.start_date}") log.info(f"\tEnd: {self.end_date}") - log.info(f"\tTime windows: {len(self.timewindows)}") + log.info(f"\tTime windows: {len(self.time_windows)}") log.info(f"\tRegion: {self.region.name if self.region else None}") log.info( f"\tMagnitude range: [{numpy.min(self.magnitudes)}," @@ -295,7 +295,7 @@ def stage_models(self) -> None: """ log.info("Staging models") for i in self.models: - i.stage(self.timewindows) + i.stage(self.time_windows) self.registry.add_forecast_registry(i) def set_tests(self, test_config: Union[str, Dict, List]) -> list: @@ -376,17 +376,17 @@ def set_tasks(self) -> None: """ # Set the file path structure - self.registry.build_tree(self.timewindows, self.models, self.tests) + self.registry.build_tree(self.time_windows, self.models, self.tests) log.debug("Pre-run forecast summary") - self.registry.log_forecast_trees(self.timewindows) + self.registry.log_forecast_trees(self.time_windows) log.debug("Pre-run result summary") self.registry.log_results_tree() log.info("Setting up experiment's tasks") # Get the time windows strings - tw_strings = timewindow2str(self.timewindows) + tw_strings = timewindow2str(self.time_windows) # Prepare the testing catalogs task_graph = TaskGraph() @@ -481,7 +481,7 @@ def set_tasks(self) -> None: ) # Set up the Sequential_Comparative Scores elif test_k.type == "sequential_comparative": - tw_strs = timewindow2str(self.timewindows) + tw_strs = timewindow2str(self.time_windows) for model_j in self.models: task_k = Task( instance=test_k, @@ -504,7 +504,7 @@ def set_tasks(self) -> None: ) # Set up the Batch comparative Scores elif test_k.type == "batch": - time_str = timewindow2str(self.timewindows[-1]) + time_str = timewindow2str(self.time_windows[-1]) for model_j in self.models: task_k = Task( instance=test_k, @@ -540,7 +540,7 @@ def run(self) -> None: 
self.task_graph.run() log.info("Calculation completed") log.debug("Post-run forecast registry") - self.registry.log_forecast_trees(self.timewindows) + self.registry.log_forecast_trees(self.time_windows) log.debug("Post-run result summary") self.registry.log_results_tree() @@ -604,7 +604,7 @@ def as_dict(self, extra: Sequence = (), extended=False) -> dict: "time_config": { i: j for i, j in self.time_config.items() - if (i not in ("timewindows",) or extended) + if (i not in ("time_windows",) or extended) }, "region_config": { i: j @@ -731,7 +731,7 @@ def test_stat(test_orig, test_repr): def get_results(self): - win_orig = timewindow2str(self.original.timewindows) + win_orig = timewindow2str(self.original.time_windows) tests_orig = self.original.tests @@ -787,7 +787,7 @@ def get_hash(filename): def get_filecomp(self): - win_orig = timewindow2str(self.original.timewindows) + win_orig = timewindow2str(self.original.time_windows) tests_orig = self.original.tests diff --git a/floatcsep/infrastructure/registries.py b/floatcsep/infrastructure/registries.py index db655de..c5bea8e 100644 --- a/floatcsep/infrastructure/registries.py +++ b/floatcsep/infrastructure/registries.py @@ -395,12 +395,12 @@ def get_forecast_registry(self, model_name: str) -> None: """ return self.forecast_registries.get(model_name) - def log_forecast_trees(self, timewindows: list) -> None: + def log_forecast_trees(self, time_windows: list) -> None: """ Logs the forecasts for all models managed by this ExperimentRegistry. 
""" log.debug("===================") - log.debug(f" Total Time Windows: {len(timewindows)}") + log.debug(f" Total Time Windows: {len(time_windows)}") for model_name, registry in self.forecast_registries.items(): log.debug(f" Model: {model_name}") registry.log_tree() diff --git a/floatcsep/infrastructure/repositories.py b/floatcsep/infrastructure/repositories.py index 94781fc..b138814 100644 --- a/floatcsep/infrastructure/repositories.py +++ b/floatcsep/infrastructure/repositories.py @@ -381,8 +381,8 @@ def catalog(self) -> CSEPCatalog: if isfile(self.cat_path): return CSEPCatalog.load_json(self.cat_path) bounds = { - "start_time": min([item for sublist in self.timewindows for item in sublist]), - "end_time": max([item for sublist in self.timewindows for item in sublist]), + "start_time": min([item for sublist in self.time_windows for item in sublist]), + "end_time": max([item for sublist in self.time_windows for item in sublist]), "min_magnitude": self.magnitudes.min(), "max_depth": self.depths.max(), } diff --git a/floatcsep/model.py b/floatcsep/model.py index e223e49..c7300eb 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -60,7 +60,7 @@ def __init__( self.__dict__.update(**kwargs) @abstractmethod - def stage(self, timewindows=None) -> None: + def stage(self, time_windows=None) -> None: """Prepares the stage for a model run.""" pass @@ -215,13 +215,13 @@ def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, self.registry, model_class=self.__class__.__name__, **kwargs ) - def stage(self, timewindows: Sequence[Sequence[datetime]] = None) -> None: + def stage(self, time_windows: Sequence[Sequence[datetime]] = None) -> None: """ Acquire the forecast data if it is not in the file system. Sets the paths internally (or database pointers) to the forecast data. Args: - timewindows (list): time_windows that the forecast data represents. + time_windows (list): time_windows that the forecast data represents. 
""" if self.force_stage or not self.registry.file_exists("path"): @@ -231,7 +231,7 @@ def stage(self, timewindows: Sequence[Sequence[datetime]] = None) -> None: if self.store_db: self.init_db() - self.registry.build_tree(time_windows=timewindows, model_class=self.__class__.__name__) + self.registry.build_tree(time_windows=time_windows, model_class=self.__class__.__name__) def init_db(self, dbpath: str = "", force: bool = False) -> None: """ @@ -333,7 +333,7 @@ def __init__( self.build, self.name, self.registry.abs(model_path) ) - def stage(self, timewindows=None) -> None: + def stage(self, time_windows=None) -> None: """ Core method to interface a model with the experiment. @@ -351,7 +351,7 @@ def stage(self, timewindows=None) -> None: self.environment.create_environment(force=self.force_build) self.registry.build_tree( - time_windows=timewindows, + time_windows=time_windows, model_class=self.__class__.__name__, prefix=self.__dict__.get("prefix", self.name), args_file=self.__dict__.get("args_file", None), diff --git a/floatcsep/postprocess/plot_handler.py b/floatcsep/postprocess/plot_handler.py index 18399d6..8b70532 100644 --- a/floatcsep/postprocess/plot_handler.py +++ b/floatcsep/postprocess/plot_handler.py @@ -27,10 +27,10 @@ def plot_results(experiment: "Experiment") -> None: """ log.info("Plotting evaluation results") - timewindows = timewindow2str(experiment.timewindows) + time_windows = timewindow2str(experiment.time_windows) for test in experiment.tests: - test.plot_results(timewindows, experiment.models, experiment.registry) + test.plot_results(time_windows, experiment.models, experiment.registry) def plot_forecasts(experiment: "Experiment") -> None: @@ -76,9 +76,9 @@ def plot_forecasts(experiment: "Experiment") -> None: # Get the time windows to be plotted. Defaults to only the last time window. 
time_windows = ( - timewindow2str(experiment.timewindows) + timewindow2str(experiment.time_windows) if plot_forecast_config.get("all_time_windows") - else [timewindow2str(experiment.timewindows[-1])] + else [timewindow2str(experiment.time_windows[-1])] ) # Get the projection of the plots @@ -177,7 +177,7 @@ def plot_catalogs(experiment: "Experiment") -> None: # If selected, plot the test catalogs for each of the time windows if plot_catalog_config.get("all_time_windows"): - for tw in experiment.timewindows: + for tw in experiment.time_windows: test_catalog = experiment.catalog_repo.get_test_cat(timewindow2str(tw)) if test_catalog.get_number_of_events() != 0: diff --git a/floatcsep/postprocess/reporting.py b/floatcsep/postprocess/reporting.py index 7ec0e5d..0f1d9f7 100644 --- a/floatcsep/postprocess/reporting.py +++ b/floatcsep/postprocess/reporting.py @@ -35,7 +35,7 @@ def generate_report(experiment, timewindow=-1): custom_report(report_function, experiment) return - timewindow = experiment.timewindows[timewindow] + timewindow = experiment.time_windows[timewindow] timestr = timewindow2str(timewindow) log.info(f"Saving report into {experiment.registry.run_dir}") diff --git a/floatcsep/utils/helpers.py b/floatcsep/utils/helpers.py index 868e865..3a1eda9 100644 --- a/floatcsep/utils/helpers.py +++ b/floatcsep/utils/helpers.py @@ -158,14 +158,14 @@ def read_time_cfg(time_config, **kwargs): if "offset" in time_config.keys(): time_config["offset"] = parse_timedelta_string(time_config["offset"]) - if not time_config.get("timewindows"): + if not time_config.get("time_windows"): if experiment_class == "ti": - time_config["timewindows"] = timewindows_ti(**time_config) + time_config["time_windows"] = time_windows_ti(**time_config) elif experiment_class == "td": - time_config["timewindows"] = timewindows_td(**time_config) + time_config["time_windows"] = time_windows_td(**time_config) else: - time_config["start_date"] = time_config["timewindows"][0][0] - time_config["end_date"] 
= time_config["timewindows"][-1][-1] + time_config["start_date"] = time_config["time_windows"][0][0] + time_config["end_date"] = time_config["time_windows"][-1][-1] return time_config @@ -242,7 +242,7 @@ def timewindow2str(datetimes: Sequence) -> Union[str, list[str]]: single timewindow or a list of time windows. Args: - datetimes: A sequence (of sequences) of datetimes, representing a list of timewindows + datetimes: A sequence (of sequences) of datetimes, representing a list of time_windows Returns: A sequence of strings for each time window @@ -279,7 +279,7 @@ def str2timewindow( return datetimes -def timewindows_ti( +def time_windows_ti( start_date=None, end_date=None, intervals=None, horizon=None, growth="incremental", **_ ): """ @@ -336,7 +336,7 @@ def timewindows_ti( return [(timelimits[0], i) for i in timelimits[1:]] -def timewindows_td( +def time_windows_td( start_date=None, end_date=None, timeintervals=None, timehorizon=None, timeoffset=None, **_ ): """ diff --git a/tests/qa/test_data.py b/tests/e2e/test_data.py similarity index 100% rename from tests/qa/test_data.py rename to tests/e2e/test_data.py diff --git a/tests/integration/test_model_infrastructure.py b/tests/integration/test_model_infrastructure.py index 96adb74..7dda39f 100644 --- a/tests/integration/test_model_infrastructure.py +++ b/tests/integration/test_model_infrastructure.py @@ -35,10 +35,10 @@ def setUp(self): ) def test_time_independent_model_stage(self): - timewindows = [ + time_windows = [ [datetime(2023, 1, 1), datetime(2023, 1, 2)], ] - self.time_independent_model.stage(timewindows=timewindows) + self.time_independent_model.stage(time_windows=time_windows) print("a", self.time_independent_model.registry.as_dict()) self.assertIn("2023-01-01_2023-01-02", self.time_independent_model.registry.forecasts) @@ -50,10 +50,10 @@ def test_time_independent_model_get_forecast(self): def test_time_independent_model_get_forecast_real(self): tstring = "2023-01-01_2023-01-02" - timewindows = [ + 
time_windows = [ [datetime(2023, 1, 1), datetime(2023, 1, 2)], ] - self.time_independent_model.stage(timewindows=timewindows) + self.time_independent_model.stage(time_windows=time_windows) forecast = self.time_independent_model.get_forecast(tstring) self.assertIsInstance(forecast, GriddedForecast) self.assertAlmostEqual(forecast.data[0, 0], 0.002739726027357392) # 1 / 365 days @@ -63,12 +63,12 @@ def test_time_independent_model_get_forecast_real(self): def test_time_dependent_model_stage(self, mock_venv, mock_conda): mock_venv.return_value = None mock_conda.return_value = None - timewindows = [ + time_windows = [ [datetime(2020, 1, 1), datetime(2020, 1, 2)], [datetime(2020, 1, 2), datetime(2020, 1, 3)], ] tstrings = ["2020-01-01_2020-01-02", "2020-01-02_2020-01-03"] - self.time_dependent_model.stage(timewindows=timewindows) + self.time_dependent_model.stage(time_windows=time_windows) self.assertIn(tstrings[0], self.time_dependent_model.registry.forecasts) self.assertIn(tstrings[1], self.time_dependent_model.registry.forecasts) @@ -78,11 +78,11 @@ def test_time_dependent_model_stage(self, mock_venv, mock_conda): def test_time_dependent_model_get_forecast(self, mock_venv, mock_conda): mock_venv.return_value = None mock_conda.return_value = None - timewindows = [ + time_windows = [ [datetime(2020, 1, 1), datetime(2020, 1, 2)], [datetime(2020, 1, 2), datetime(2020, 1, 3)], ] - self.time_dependent_model.stage(timewindows) + self.time_dependent_model.stage(time_windows) tstring = "2020-01-01_2020-01-02" forecast = self.time_dependent_model.get_forecast(tstring) self.assertIsNotNone(forecast) diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py index a739a41..ed7d116 100644 --- a/tests/unit/test_experiment.py +++ b/tests/unit/test_experiment.py @@ -41,8 +41,8 @@ def assertEqualExperiment(self, exp_a, exp_b): self.assertEqual(exp_a.registry.workdir, os.getcwd()) self.assertEqual(exp_a.registry.workdir, exp_b.registry.workdir) 
self.assertEqual(exp_a.start_date, exp_b.start_date) - print(exp_a.timewindows, exp_b.timewindows) - self.assertEqual(exp_a.timewindows, exp_b.timewindows) + print(exp_a.time_windows, exp_b.time_windows) + self.assertEqual(exp_a.time_windows, exp_b.time_windows) self.assertEqual(exp_a.exp_class, exp_b.exp_class) self.assertEqual(exp_a.region, exp_b.region) numpy.testing.assert_equal(exp_a.magnitudes, exp_b.magnitudes) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 66a5c3d..15f00c3 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -224,7 +224,7 @@ def test_init(self): def test_stage(self, mk): self.model.force_stage = True # Force staging to occur - self.model.stage(timewindows=["2020-01-01_2020-12-31"]) + self.model.stage(time_windows=["2020-01-01_2020-12-31"]) self.mock_get_source.assert_called_once_with( self.model.zenodo_id, self.model.giturl, branch=self.model.repo_hash diff --git a/tests/unit/test_plot_handler.py b/tests/unit/test_plot_handler.py index d1fd852..ebabb3f 100644 --- a/tests/unit/test_plot_handler.py +++ b/tests/unit/test_plot_handler.py @@ -15,7 +15,7 @@ def test_plot_results(self, mock_timewindow2str, mock_savefig): plot_handler.plot_results(mock_experiment) - mock_timewindow2str.assert_called_once_with(mock_experiment.timewindows) + mock_timewindow2str.assert_called_once_with(mock_experiment.time_windows) mock_test.plot_results.assert_called_once_with( ["2021-01-01", "2021-12-31"], mock_experiment.models, mock_experiment.registry ) diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index 5529a05..03ba40e 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -67,9 +67,9 @@ def test_file_exists(self, mock_exists): @patch("os.makedirs") @patch("os.listdir") def test_build_tree_time_independent(self, mock_listdir, mock_makedirs): - timewindows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] + time_windows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] 
self.registry_for_filebased_model.build_tree( - time_windows=timewindows, model_class="TimeIndependentModel" + time_windows=time_windows, model_class="TimeIndependentModel" ) self.assertIn("2023-01-01_2023-01-02", self.registry_for_filebased_model.forecasts) # self.assertIn("2023-01-01_2023-01-02", self.registry_for_filebased_model.inventory) @@ -78,12 +78,12 @@ def test_build_tree_time_independent(self, mock_listdir, mock_makedirs): @patch("os.listdir") def test_build_tree_time_dependent(self, mock_listdir, mock_makedirs): mock_listdir.return_value = ["forecast_1.csv"] - timewindows = [ + time_windows = [ [datetime(2023, 1, 1), datetime(2023, 1, 2)], [datetime(2023, 1, 2), datetime(2023, 1, 3)], ] self.registry_for_folderbased_model.build_tree( - time_windows=timewindows, model_class="TimeDependentModel", prefix="forecast" + time_windows=time_windows, model_class="TimeDependentModel", prefix="forecast" ) self.assertIn("2023-01-01_2023-01-02", self.registry_for_folderbased_model.forecasts) # self.assertTrue(self.registry_for_folderbased_model.inventory["2023-01-01_2023-01-02"]) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 341339f..e4940a2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -11,11 +11,11 @@ import floatcsep.utils.accessors from floatcsep.utils.helpers import ( parse_timedelta_string, - timewindows_ti, + time_windows_ti, read_time_cfg, read_region_cfg, parse_csep_func, - timewindows_td, + time_windows_td, ) root_dir = os.path.dirname(os.path.abspath(__file__)) @@ -58,19 +58,19 @@ def test_parse_time_window(self): dt = "1decade" self.assertRaises(ValueError, parse_timedelta_string, dt) - def test_timewindows_ti(self): + def test_time_windows_ti(self): start = datetime(2014, 1, 1) end = datetime(2022, 1, 1) - self.assertEqual(timewindows_ti(start_date=start, end_date=end), [(start, end)]) + self.assertEqual(time_windows_ti(start_date=start, end_date=end), [(start, end)]) t1 = [ (datetime(2014, 1, 1), 
datetime(2018, 1, 1)), (datetime(2018, 1, 1), datetime(2022, 1, 1)), ] - self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=2), t1) - self.assertEqual(timewindows_ti(start_date=start, end_date=end, horizon="4-years"), t1) - self.assertEqual(timewindows_ti(start_date=start, intervals=2, horizon="4-years"), t1) + self.assertEqual(time_windows_ti(start_date=start, end_date=end, intervals=2), t1) + self.assertEqual(time_windows_ti(start_date=start, end_date=end, horizon="4-years"), t1) + self.assertEqual(time_windows_ti(start_date=start, intervals=2, horizon="4-years"), t1) t2 = [ (datetime(2014, 1, 1, 0, 0), datetime(2015, 2, 22, 10, 17, 8, 571428)), @@ -93,18 +93,18 @@ def test_timewindows_ti(self): ), (datetime(2020, 11, 9, 13, 42, 51, 428571), datetime(2022, 1, 1, 0, 0)), ] - self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=7), t2) + self.assertEqual(time_windows_ti(start_date=start, end_date=end, intervals=7), t2) - def test_timewindows_td(self): + def test_time_windows_td(self): start = datetime(2010, 1, 1) end = datetime(2020, 1, 1) self.assertEqual( - timewindows_td(start_date=start, end_date=end, timeintervals=1), [(start, end)] + time_windows_td(start_date=start, end_date=end, timeintervals=1), [(start, end)] ) self.assertEqual( - timewindows_td(start_date=start, end_date=end, timeintervals=5), + time_windows_td(start_date=start, end_date=end, timeintervals=5), [ (datetime(2010, 1, 1, 0, 0), datetime(2012, 1, 1, 9, 36)), (datetime(2012, 1, 1, 9, 36), datetime(2013, 12, 31, 19, 12)), @@ -115,11 +115,11 @@ def test_timewindows_td(self): ) self.assertEqual( - timewindows_td(start_date=start, timeintervals=1, timehorizon="1-years"), + time_windows_td(start_date=start, timeintervals=1, timehorizon="1-years"), [(datetime(2010, 1, 1, 0, 0), datetime(2011, 1, 1, 0, 0))], ) self.assertEqual( - timewindows_td(start_date=start, timeintervals=5, timehorizon="5-days"), + time_windows_td(start_date=start, timeintervals=5, 
timehorizon="5-days"), [ (datetime(2010, 1, 1, 0, 0), datetime(2010, 1, 6, 0, 0)), (datetime(2010, 1, 6, 0, 0), datetime(2010, 1, 11, 0, 0)), @@ -129,19 +129,19 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="10-years", timeoffset="10-years" ), [(datetime(2010, 1, 1, 0, 0), datetime(2020, 1, 1, 0, 0))], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="12-years", timeoffset="10-years" ), [(datetime(2010, 1, 1, 0, 0), datetime(2022, 1, 1, 0, 0))], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="5-years", timeoffset="5-years" ), [ @@ -150,7 +150,7 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="5-years", timeoffset="3-years" ), [ @@ -160,7 +160,7 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=datetime(2010, 2, 1), timehorizon="14-days", @@ -174,7 +174,7 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, timeintervals=3, timehorizon="3-years", From 796a51e58c1ceed205171e3ab872369bf85c248e Mon Sep 17 00:00:00 2001 From: pciturri Date: Tue, 29 Apr 2025 13:58:02 +0200 Subject: [PATCH 3/5] ft: Model.__init__ instantiates now a ModelRegistry factory --- floatcsep/infrastructure/registries.py | 21 ++++++++++++++++++++- floatcsep/model.py | 9 +++++---- tests/unit/test_model.py | 8 ++++---- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/floatcsep/infrastructure/registries.py b/floatcsep/infrastructure/registries.py index c5bea8e..3d27d60 100644 --- a/floatcsep/infrastructure/registries.py +++ b/floatcsep/infrastructure/registries.py @@ -27,6 +27,17 @@ def get_forecast_key(self, tstring: str) -> str: def get_args_key(self, tstring: str) -> str: 
pass + @classmethod + def factory(cls, registry_type: str = 'file', **kwargs) -> "ModelRegistry": + """Factory method. Instantiate first on any explicit option provided in the model + configuration. + """ + + if registry_type == 'file': + return ModelFileRegistry(**kwargs) + + elif registry_type == 'hdf5': + return ModelHDF5Registry(**kwargs) class ModelFileRegistry(ModelRegistry): @@ -281,8 +292,16 @@ def log_tree(self) -> None: for timewindow in not_exist_group: log.debug(f" Time Window: {timewindow}") +class ModelHDF5Registry(ModelRegistry): - + def __init__(self, workdir: str, path: str): + pass + def get_input_catalog_key(self, tstring: str) -> str: + return '' + def get_forecast_key(self, tstring: str) -> str: + return '' + def get_args_key(self, tstring: str) -> str: + return '' class FileRegistry(ABC): diff --git a/floatcsep/model.py b/floatcsep/model.py index c7300eb..d5c48dc 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -12,7 +12,7 @@ from floatcsep.utils.accessors import from_zenodo, from_git from floatcsep.infrastructure.environments import EnvironmentFactory from floatcsep.utils.readers import ForecastParsers, HDF5Serializer -from floatcsep.infrastructure.registries import ModelFileRegistry +from floatcsep.infrastructure.registries import ModelRegistry from floatcsep.infrastructure.repositories import ForecastRepository from floatcsep.utils.helpers import timewindow2str, str2timewindow, parse_nested_dicts @@ -210,7 +210,8 @@ def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, self.forecast_unit = forecast_unit self.store_db = store_db - self.registry = ModelFileRegistry(kwargs.get("workdir", os.getcwd()), model_path) # todo: Set factory for registry. 
+ self.registry = ModelRegistry.factory(workdir=kwargs.get("workdir", os.getcwd()), + path=model_path) self.repository = ForecastRepository.factory( self.registry, model_class=self.__class__.__name__, **kwargs ) @@ -320,9 +321,9 @@ def __init__( self.func = func self.func_kwargs = func_kwargs or {} - self.registry = ModelFileRegistry(workdir=kwargs.get("workdir", os.getcwd()), + self.registry = ModelRegistry.factory(workdir=kwargs.get("workdir", os.getcwd()), path=model_path, - fmt=fmt) # todo: Set Factory for Registry + fmt=fmt) self.repository = ForecastRepository.factory( self.registry, model_class=self.__class__.__name__, **kwargs ) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 15f00c3..125294a 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -159,7 +159,7 @@ class TestTimeDependentModel(TestModel): def setUp(self): # Patches - self.patcher_registry = patch("floatcsep.model.ModelFileRegistry") + self.patcher_registry = patch("floatcsep.model.ModelRegistry.factory") self.patcher_repository = patch("floatcsep.model.ForecastRepository.factory") self.patcher_environment = patch("floatcsep.model.EnvironmentFactory.get_env") self.patcher_get_source = patch( @@ -167,14 +167,14 @@ def setUp(self): ) # Patch the get_source method on Model # Start patches - self.mock_registry = self.patcher_registry.start() + self.mock_registry_factory = self.patcher_registry.start() self.mock_repository_factory = self.patcher_repository.start() self.mock_environment = self.patcher_environment.start() self.mock_get_source = self.patcher_get_source.start() # Mock instances self.mock_registry_instance = MagicMock() - self.mock_registry.return_value = self.mock_registry_instance + self.mock_registry_factory.return_value = self.mock_registry_instance self.mock_repository_instance = MagicMock() self.mock_repository_factory.return_value = self.mock_repository_instance @@ -204,7 +204,7 @@ def tearDown(self): def test_init(self): # Assertions to 
check if the components were instantiated correctly - self.mock_registry.assert_called_once_with( + self.mock_registry_factory.assert_called_once_with( workdir=os.getcwd(), path=self.model_path, fmt='csv' ) # Ensure the registry is initialized correctly self.mock_repository_factory.assert_called_once_with( From 71f1ea993793c82d3cc07029a41fc6e8380bfa63 Mon Sep 17 00:00:00 2001 From: pciturri Date: Tue, 29 Apr 2025 16:13:13 +0200 Subject: [PATCH 4/5] refac: abstracted filepath management functionality as a Mixin to be used in file-base registries. Created ExperimentRegistry abstract and File-based concrete classes. Renamed registries' interface to add/get_*_key. abstracted registry logging to be a function in the logger module. tests: added unit tests for experiment file registry --- floatcsep/evaluation.py | 6 +- floatcsep/experiment.py | 23 +- floatcsep/infrastructure/logger.py | 71 ++++ floatcsep/infrastructure/registries.py | 409 ++++++++++------------- floatcsep/infrastructure/repositories.py | 6 +- floatcsep/postprocess/plot_handler.py | 10 +- floatcsep/postprocess/reporting.py | 8 +- tests/unit/test_plot_handler.py | 2 +- tests/unit/test_registry.py | 83 ++++- tests/unit/test_reporting.py | 2 +- tests/unit/test_repositories.py | 14 +- tutorials/case_h/custom_report.py | 6 +- 12 files changed, 369 insertions(+), 271 deletions(-) diff --git a/floatcsep/evaluation.py b/floatcsep/evaluation.py index 93c25ba..fef6f69 100644 --- a/floatcsep/evaluation.py +++ b/floatcsep/evaluation.py @@ -287,7 +287,7 @@ def plot_results( # Regular consistency/comparative test plots (e.g., many models) try: for time_str in timewindow: - fig_path = registry.get_figure(time_str, self.name) + fig_path = registry.get_figure_key(time_str, self.name) results = self.read_results(time_str, models) ax = func(results, plot_args=fargs, **fkwargs) if "code" in fargs: @@ -307,7 +307,7 @@ def plot_results( registry.figures[time_str][fig_name] = os.path.join( time_str, "figures", fig_name ) - 
fig_path = registry.get_figure(time_str, fig_name) + fig_path = registry.get_figure_key(time_str, fig_name) ax = func(result, plot_args=fargs, **fkwargs, show=False) if "code" in fargs: exec(fargs["code"]) @@ -318,7 +318,7 @@ def plot_results( pyplot.show() elif self.type in ["sequential", "sequential_comparative", "batch"]: - fig_path = registry.get_figure(timewindow[-1], self.name) + fig_path = registry.get_figure_key(timewindow[-1], self.name) results = self.read_results(timewindow[-1], models) ax = func(results, plot_args=fargs, **fkwargs) diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index cb36d70..12be522 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -25,6 +25,7 @@ parse_nested_dicts, ) from floatcsep.infrastructure.engine import Task, TaskGraph +from floatcsep.infrastructure.logger import log_models_tree, log_results_tree log = logging.getLogger("floatLogger") @@ -118,7 +119,7 @@ def __init__( os.makedirs(os.path.join(workdir, rundir), exist_ok=True) self.name = name if name else "floatingExp" - self.registry = ExperimentRegistry(workdir, rundir) + self.registry = ExperimentRegistry.factory(workdir=workdir, run_dir=rundir) self.results_repo = ResultsRepository(self.registry) self.catalog_repo = CatalogRepository(self.registry) @@ -296,7 +297,7 @@ def stage_models(self) -> None: log.info("Staging models") for i in self.models: i.stage(self.time_windows) - self.registry.add_forecast_registry(i) + self.registry.add_model_registry(i) def set_tests(self, test_config: Union[str, Dict, List]) -> list: """ @@ -379,9 +380,9 @@ def set_tasks(self) -> None: self.registry.build_tree(self.time_windows, self.models, self.tests) log.debug("Pre-run forecast summary") - self.registry.log_forecast_trees(self.time_windows) + log_models_tree(log, self.registry, self.time_windows) log.debug("Pre-run result summary") - self.registry.log_results_tree() + log_results_tree(log, self.registry) log.info("Setting up experiment's tasks") @@ 
-540,9 +541,9 @@ def run(self) -> None: self.task_graph.run() log.info("Calculation completed") log.debug("Post-run forecast registry") - self.registry.log_forecast_trees(self.time_windows) + log_models_tree(log, self.registry, self.time_windows) log.debug("Post-run result summary") - self.registry.log_results_tree() + log_results_tree(log, self.registry) def read_results(self, test: Evaluation, window: str) -> List: """ @@ -559,7 +560,7 @@ def make_repr(self) -> None: """ log.info("Creating reproducibility config file") - repr_config = self.registry.get("repr_config") + repr_config = self.registry.get_attr("repr_config") # Dropping region to results folder if it is a file region_path = self.region_config.get("path", False) @@ -801,8 +802,8 @@ def get_filecomp(self): for tw in win_orig: results[test.name][tw] = dict.fromkeys(models_orig) for model in models_orig: - orig_path = self.original.registry.get_result(tw, test, model) - repr_path = self.reproduced.registry.get_result(tw, test, model) + orig_path = self.original.registry.get_result_key(tw, test, model) + repr_path = self.reproduced.registry.get_result_key(tw, test, model) results[test.name][tw][model] = { "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)), @@ -811,8 +812,8 @@ def get_filecomp(self): else: results[test.name] = dict.fromkeys(models_orig) for model in models_orig: - orig_path = self.original.registry.get_result(win_orig[-1], test, model) - repr_path = self.reproduced.registry.get_result(win_orig[-1], test, model) + orig_path = self.original.registry.get_result_key(win_orig[-1], test, model) + repr_path = self.reproduced.registry.get_result_key(win_orig[-1], test, model) results[test.name][model] = { "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)), "byte2byte": filecmp.cmp(orig_path, repr_path), diff --git a/floatcsep/infrastructure/logger.py b/floatcsep/infrastructure/logger.py index e96fad9..0e71036 100644 --- a/floatcsep/infrastructure/logger.py +++ 
b/floatcsep/infrastructure/logger.py @@ -60,3 +60,74 @@ def set_console_log_level(log_level): for handler in logger.handlers: if isinstance(handler, logging.StreamHandler): handler.setLevel(log_level) + + + + +def log_models_tree(log, experiment_registry, time_windows): + """ + Logs the forecasts for all models managed by this ExperimentFileRegistry. + """ + log.debug("===================") + log.debug(f" Total Time Windows: {len(time_windows)}") + for model_name, registry in experiment_registry.model_registries.items(): + log.debug(f" Model: {model_name}") + exists_group = [] + not_exist_group = [] + + for timewindow, filepath in registry.forecasts.items(): + if registry.forecast_exists(timewindow): + exists_group.append(timewindow) + else: + not_exist_group.append(timewindow) + + log.debug(f" Existing forecasts: {len(exists_group)}") + log.debug(f" Missing forecasts: {len(not_exist_group)}") + for timewindow in not_exist_group: + log.debug(f" Time Window: {timewindow}") + log.debug("===================") + + +def log_results_tree(log, experiment_registry): + """ + Logs a summary of the results dictionary, sorted by test. + For each test and time window, it logs whether all models have results, + or if some results are missing, and specifies which models are missing. 
+ """ + log.debug("===================") + + total_results = results_exist_count = results_not_exist_count = 0 + + # Get all unique test names and sort them + all_tests = sorted( + {test_name for tests in experiment_registry.results.values() for test_name in tests} + ) + + for test_name in all_tests: + log.debug(f"Test: {test_name}") + for timewindow, tests in experiment_registry.results.items(): + if test_name in tests: + models = tests[test_name] + missing_models = [] + + for model_name, result_path in models.items(): + total_results += 1 + result_full_path = experiment_registry.get_result_key(timewindow, test_name, model_name) + if os.path.exists(result_full_path): + results_exist_count += 1 + else: + results_not_exist_count += 1 + missing_models.append(model_name) + + if not missing_models: + log.debug(f" Time Window: {timewindow} - All models evaluated.") + else: + log.debug( + f" Time Window: {timewindow} - Missing results for models: " + f"{', '.join(missing_models)}" + ) + + log.debug(f"Total Results: {total_results}") + log.debug(f"Results that Exist: {results_exist_count}") + log.debug(f"Results that Do Not Exist: {results_not_exist_count}") + log.debug("===================") \ No newline at end of file diff --git a/floatcsep/infrastructure/registries.py b/floatcsep/infrastructure/registries.py index 3d27d60..3fcbfd0 100644 --- a/floatcsep/infrastructure/registries.py +++ b/floatcsep/infrastructure/registries.py @@ -14,6 +14,114 @@ log = logging.getLogger("floatLogger") +class FilepathMixin: + """ + Small mixin to provide filepath management functionality to Registries that uses files to + store objects + """ + workdir: str + + @staticmethod + def _parse_arg(arg) -> Union[str, list[str]]: + if isinstance(arg, (list, tuple)): + return timewindow2str(arg) + elif isinstance(arg, str): + return arg + elif hasattr(arg, "name"): + return arg.name + elif hasattr(arg, "__name__"): + return arg.__name__ + else: + raise Exception("Arg is not found") + + def 
get_attr(self, *args: Sequence[str]) -> str: + """ + Access instance attributes and its contents (e.g., through dict keys) recursively in a + normalized function call. Returns the expected absolute path of this element + + Args: + *args: A sequence of keys (usually time-window strings) + + Returns: + The registry element (forecast, catalogs, etc.) from a sequence of key value + (usually time-window strings) as filepath + """ + + val = self.__dict__ + for i in args: + parsed_arg = self._parse_arg(i) + val = val[parsed_arg] + return self.abs(val) + + def abs(self, *paths: Sequence[str]) -> str: + """ + Returns the absolute path of an object, relative to the Registry workdir. + + Args: + *paths: + + Returns: + + """ + _path = normpath(abspath(join(self.workdir, *paths))) + return _path + + def abs_dir(self, *paths: Sequence[str]) -> str: + """ + Returns the absolute path of the directory containing an item relative to the Registry + workdir. + Args: + *paths: sequence of keys (usually time-window strings) + + Returns: + String describing the absolute directory + """ + _path = normpath(abspath(join(self.workdir, *paths))) + _dir = dirname(_path) + return _dir + + def rel(self, *paths: Sequence[str]) -> str: + """ + Gets the relative path of an item, relative to the Registry workdir + + Args: + *paths: sequence of keys (usually time-window strings) + Returns: + String describing the relative path + """ + + _abspath = normpath(abspath(join(self.workdir, *paths))) + _relpath = relpath(_abspath, self.workdir) + return _relpath + + def rel_dir(self, *paths: Sequence[str]) -> str: + """ + Gets the relative path of the directory containing an item, relative to the Registry + workdir + + Args: + *paths: sequence of keys (usually time-window strings) + Returns: + String describing the relative path + """ + + _path = normpath(abspath(join(self.workdir, *paths))) + _dir = dirname(_path) + + return relpath(_dir, self.workdir) + + def file_exists(self, *args: Sequence[str]): + """ 
+ Determine if such a file exists in the filesystem + + Args: + *args: sequence of keys (usually time-window strings) + Returns: + flag indicating if file exists + """ + file_abspath = self.get_attr(*args) + return exists(file_abspath) + class ModelRegistry(ABC): @abstractmethod def get_input_catalog_key(self, tstring: str) -> str: @@ -40,7 +148,7 @@ def factory(cls, registry_type: str = 'file', **kwargs) -> "ModelRegistry": return ModelHDF5Registry(**kwargs) -class ModelFileRegistry(ModelRegistry): +class ModelFileRegistry(ModelRegistry, FilepathMixin): def __init__( self, workdir: str, @@ -94,69 +202,6 @@ def fmt(self) -> str: return ext else: return self._fmt - @staticmethod - def _parse_arg(arg) -> Union[str, list[str]]: - if isinstance(arg, (list, tuple)): - return timewindow2str(arg) - elif isinstance(arg, str): - return arg - elif hasattr(arg, "name"): - return arg.name - elif hasattr(arg, "__name__"): - return arg.__name__ - else: - raise Exception("Arg is not found") - - def get_attr(self, *args: Sequence[str]) -> str: - """ - Args: - *args: A sequence of keys (usually time-window strings) - - Returns: - The registry element (forecast, catalogs, etc.) from a sequence of key value - (usually time-window strings) - """ - - val = self.__dict__ - for i in args: - parsed_arg = self._parse_arg(i) - val = val[parsed_arg] - return self.abs(val) - - def abs(self, *paths: Sequence[str]) -> str: - - _path = normpath(abspath(join(self.workdir, *paths))) - return _path - - def abs_dir(self, *paths: Sequence[str]) -> str: - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - return _dir - - def rel(self, *paths: Sequence[str]) -> str: - """Gets the relative path of a file, when it was defined relative to. - - the experiment working dir.
- """ - - _abspath = normpath(abspath(join(self.workdir, *paths))) - _relpath = relpath(_abspath, self.workdir) - return _relpath - - def rel_dir(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to. - - the experiment working dir. - """ - - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - - return relpath(_dir, self.workdir) - - def file_exists(self, *args: Sequence[str]): - file_abspath = self.get_attr(*args) - return exists(file_abspath) def forecast_exists(self, timewindow: Union[str, list]) -> Union[bool, Sequence[bool]]: """ @@ -273,24 +318,6 @@ def as_dict(self) -> dict: "forecasts": self.forecasts, } - def log_tree(self) -> None: - """ - Logs a grouped summary of the forecasts' dictionary. - Groups time windows by whether the forecast exists or not. - """ - exists_group = [] - not_exist_group = [] - - for timewindow, filepath in self.forecasts.items(): - if self.forecast_exists(timewindow): - exists_group.append(timewindow) - else: - not_exist_group.append(timewindow) - - log.debug(f" Existing forecasts: {len(exists_group)}") - log.debug(f" Missing forecasts: {len(not_exist_group)}") - for timewindow in not_exist_group: - log.debug(f" Time Window: {timewindow}") class ModelHDF5Registry(ModelRegistry): @@ -303,72 +330,46 @@ def get_forecast_key(self, tstring: str) -> str: def get_args_key(self, tstring: str) -> str: return '' -class FileRegistry(ABC): - - def __init__(self, workdir: str) -> None: - self.workdir = workdir - - @staticmethod - def _parse_arg(arg) -> Union[str, list[str]]: - if isinstance(arg, (list, tuple)): - return timewindow2str(arg) - elif isinstance(arg, str): - return arg - elif hasattr(arg, "name"): - return arg.name - elif hasattr(arg, "__name__"): - return arg.__name__ - else: - raise Exception("Arg is not found") - +class ExperimentRegistry(ABC): @abstractmethod - def as_dict(self) -> dict: + def add_model_registry(self, model: "Model") -> None: pass 
@abstractmethod - def build_tree(self, *args, **kwargs) -> None: + def get_model_registry(self, model_name: str) -> ModelRegistry: pass @abstractmethod - def get(self, *args: Sequence[str]) -> Any: + def get_result_key(self, test_name: str, model_name: str, tstring: str) -> str: pass - def abs(self, *paths: Sequence[str]) -> str: - _path = normpath(abspath(join(self.workdir, *paths))) - return _path - - def abs_dir(self, *paths: Sequence[str]) -> str: - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - return _dir - - def rel(self, *paths: Sequence[str]) -> str: - """Gets the relative path of a file, when it was defined relative to. - - the experiment working dir. - """ + @abstractmethod + def get_figure_key(self, test_name: str, model_name: str, tstring: str) -> str: + pass - _abspath = normpath(abspath(join(self.workdir, *paths))) - _relpath = relpath(_abspath, self.workdir) - return _relpath + @abstractmethod + def get_test_catalog_key(self, tstring: str) -> str: + pass - def rel_dir(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to. + @abstractmethod + def build_tree( + self, + time_windows: Sequence[Sequence[datetime]], + models: Sequence["Model"], + tests: Sequence["Evaluation"], + ) -> None: + pass - the experiment working dir. + @classmethod + def factory(cls, registry_type: str = 'file', **kwargs) -> "ExperimentRegistry": + """Factory method. Instantiate first on any explicit option provided in the experiment + configuration. 
""" - _path = normpath(abspath(join(self.workdir, *paths))) - _dir = dirname(_path) - - return relpath(_dir, self.workdir) - - def file_exists(self, *args: Sequence[str]): - file_abspath = self.get(*args) - return exists(file_abspath) - + if registry_type == 'file': + return ExperimentFileRegistry(**kwargs) -class ExperimentRegistry(FileRegistry): +class ExperimentFileRegistry(ExperimentRegistry, FilepathMixin): """ The class has the responsibility of managing the keys (based on models, timewindow and evaluation name strings) to the structure of the experiment inputs (catalogs, models etc) @@ -383,28 +384,43 @@ def __init__(self, workdir: str, run_dir: str = "results") -> None: workdir: The working directory for the experiment run-time. run_dir: The directory in which the results will be stored. """ - super().__init__(workdir) + self.workdir = workdir self.run_dir = run_dir self.results = {} self.test_catalogs = {} self.figures = {} self.repr_config = "repr_config.yml" - self.forecast_registries = {} + self.model_registries = {} - def add_forecast_registry(self, model: "Model") -> None: + def get_attr(self, *args: Any) -> str: """ - Adds a model's ForecastRegistry to the ExperimentRegistry. + Args: + *args: A sequence of keys (usually models, tests and/or time-window strings) + + Returns: + The filepath from a sequence of key values (usually models first, then time-window + strings) + """ + val = self.__dict__ + for i in args: + parsed_arg = self._parse_arg(i) + val = val[parsed_arg] + return self.abs(self.run_dir, val) + + def add_model_registry(self, model: "Model") -> None: + """ + Adds a model's ForecastRegistry to the ExperimentFileRegistry. 
Args: model (str): A Model object """ - self.forecast_registries[model.name] = model.registry + self.model_registries[model.name] = model.registry - def get_forecast_registry(self, model_name: str) -> None: + def get_model_registry(self, model_name: str) -> None: """ - Retrieves a model's ForecastRegistry from the ExperimentRegistry. + Retrieves a model's ForecastRegistry from the ExperimentFileRegistry. Args: model_name (str): The name of the model. @@ -412,67 +428,53 @@ def get_forecast_registry(self, model_name: str) -> None: Returns: ModelRegistry: The ModelRegistry associated with the model. """ - return self.forecast_registries.get(model_name) + return self.model_registries.get(model_name) - def log_forecast_trees(self, time_windows: list) -> None: - """ - Logs the forecasts for all models managed by this ExperimentRegistry. + def result_exist(self, timewindow_str: str, test_name: str, model_name: str) -> bool: """ - log.debug("===================") - log.debug(f" Total Time Windows: {len(time_windows)}") - for model_name, registry in self.forecast_registries.items(): - log.debug(f" Model: {model_name}") - registry.log_tree() - log.debug("===================") + Checks if a given test results exist - def get(self, *args: Any) -> str: - """ Args: - *args: A sequence of keys (usually models, tests and/or time-window strings) + timewindow_str (str): String representing the time window + test_name (str): Name of the evaluation + model_name (str): Name of the model - Returns: - The filepath from a sequence of key values (usually models first, then time-window - strings) """ - val = self.__dict__ - for i in args: - parsed_arg = self._parse_arg(i) - val = val[parsed_arg] - return self.abs(self.run_dir, val) + return self.file_exists("results", timewindow_str, test_name, model_name) - def get_result(self, *args: Sequence[any]) -> str: + def get_test_catalog_key(self, *args: Sequence[any]) -> str: """ - Gets the file path of an evaluation result. 
+ Gets the file path of a testing catalog. Args: - args: A sequence of keys (usually models, tests and/or time-window strings) + *args: A sequence of keys (time-window strings) Returns: - The filepath of a serialized result + The filepath of the testing catalog for a given time-window """ - val = self.results + val = self.test_catalogs for i in args: parsed_arg = self._parse_arg(i) val = val[parsed_arg] return self.abs(self.run_dir, val) - def get_test_catalog(self, *args: Sequence[any]) -> str: + def get_result_key(self, *args: Sequence[any]) -> str: """ - Gets the file path of a testing catalog. + Gets the file path of an evaluation result. Args: - *args: A sequence of keys (time-window strings) + args: A sequence of keys (usually models, tests and/or time-window strings) Returns: - The filepath of the testing catalog for a given time-window + The filepath of a serialized result """ - val = self.test_catalogs + val = self.results for i in args: parsed_arg = self._parse_arg(i) val = val[parsed_arg] return self.abs(self.run_dir, val) - def get_figure(self, *args: Sequence[any]) -> str: + def get_figure_key(self, *args: Sequence[any]) -> str: """ Gets the file path of a result figure. 
@@ -488,22 +490,6 @@ def get_figure(self, *args: Sequence[any]) -> str: val = val[parsed_arg] return self.abs(self.run_dir, val) - def result_exist(self, timewindow_str: str, test_name: str, model_name: str) -> bool: - """ - Checks if a given test results exist - - Args: - timewindow_str (str): String representing the time window - test_name (str): Name of the evaluation - model_name (str): Name of the model - - """ - return self.file_exists("results", timewindow_str, test_name, model_name) - - def as_dict(self) -> str: - # todo: rework - return self.workdir - def build_tree( self, time_windows: Sequence[Sequence[datetime]], @@ -567,46 +553,7 @@ def build_tree( self.test_catalogs = test_catalogs self.figures = figures - def log_results_tree(self): - """ - Logs a summary of the results dictionary, sorted by test. - For each test and time window, it logs whether all models have results, - or if some results are missing, and specifies which models are missing. - """ - log.debug("===================") - - total_results = results_exist_count = results_not_exist_count = 0 - - # Get all unique test names and sort them - all_tests = sorted( - {test_name for tests in self.results.values() for test_name in tests} - ) - - for test_name in all_tests: - log.debug(f"Test: {test_name}") - for timewindow, tests in self.results.items(): - if test_name in tests: - models = tests[test_name] - missing_models = [] - - for model_name, result_path in models.items(): - total_results += 1 - result_full_path = self.get_result(timewindow, test_name, model_name) - if os.path.exists(result_full_path): - results_exist_count += 1 - else: - results_not_exist_count += 1 - missing_models.append(model_name) - - if not missing_models: - log.debug(f" Time Window: {timewindow} - All models evaluated.") - else: - log.debug( - f" Time Window: {timewindow} - Missing results for models: " - f"{', '.join(missing_models)}" - ) - - log.debug(f"Total Results: {total_results}") - log.debug(f"Results that Exist: 
{results_exist_count}") - log.debug(f"Results that Do Not Exist: {results_not_exist_count}") - log.debug("===================") + def as_dict(self) -> str: + + return self.workdir + diff --git a/floatcsep/infrastructure/repositories.py b/floatcsep/infrastructure/repositories.py index b138814..7f4f685 100644 --- a/floatcsep/infrastructure/repositories.py +++ b/floatcsep/infrastructure/repositories.py @@ -243,7 +243,7 @@ def _load_result( else: wstr_ = window - eval_path = self.registry.get_result(wstr_, test, model) + eval_path = self.registry.get_result_key(wstr_, test, model) with open(eval_path, "r") as file_: model_eval = EvaluationResult.from_dict(json.load(file_)) @@ -287,7 +287,7 @@ def write_result(self, result: EvaluationResult, test, model, window) -> None: window: Name of the time-window """ - path = self.registry.get_result(window, test, model) + path = self.registry.get_result_key(window, test, model) class NumpyEncoder(json.JSONEncoder): def default(self, obj): @@ -471,7 +471,7 @@ def set_test_cat(self, tstring: str) -> None: tstring (str): Time window string """ - testcat_name = self.registry.get_test_catalog(tstring) + testcat_name = self.registry.get_test_catalog_key(tstring) if not exists(testcat_name): log.debug( f"Filtering testing catalog and saving to {self.registry.rel(testcat_name)}" diff --git a/floatcsep/postprocess/plot_handler.py b/floatcsep/postprocess/plot_handler.py index 8b70532..2de233f 100644 --- a/floatcsep/postprocess/plot_handler.py +++ b/floatcsep/postprocess/plot_handler.py @@ -106,7 +106,7 @@ def plot_forecasts(experiment: "Experiment") -> None: } ), ) - fig_path = experiment.registry.get_figure(window, "forecasts", model.name) + fig_path = experiment.registry.get_figure_key(window, "forecasts", model.name) pyplot.savefig(fig_path, dpi=plot_forecast_config.get("dpi", 300)) @@ -167,12 +167,12 @@ def plot_catalogs(experiment: "Experiment") -> None: # Plot catalog map ax = main_catalog.plot(plot_args=plot_catalog_config) - 
cat_map_path = experiment.registry.get_figure("main_catalog_map") + cat_map_path = experiment.registry.get_figure_key("main_catalog_map") ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300)) # Plot catalog time series vs. magnitude ax = magnitude_vs_time(main_catalog) - cat_time_path = experiment.registry.get_figure("main_catalog_time") + cat_time_path = experiment.registry.get_figure_key("main_catalog_time") ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300)) # If selected, plot the test catalogs for each of the time windows @@ -185,11 +185,11 @@ def plot_catalogs(experiment: "Experiment") -> None: continue ax = test_catalog.plot(plot_args=plot_catalog_config) - cat_map_path = experiment.registry.get_figure(tw, "catalog_map") + cat_map_path = experiment.registry.get_figure_key(tw, "catalog_map") ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300)) ax = magnitude_vs_time(test_catalog) - cat_time_path = experiment.registry.get_figure(tw, "catalog_time") + cat_time_path = experiment.registry.get_figure_key(tw, "catalog_time") ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300)) diff --git a/floatcsep/postprocess/reporting.py b/floatcsep/postprocess/reporting.py index 0f1d9f7..cacae0c 100644 --- a/floatcsep/postprocess/reporting.py +++ b/floatcsep/postprocess/reporting.py @@ -60,11 +60,11 @@ def generate_report(experiment, timewindow=-1): "Input catalog", [ os.path.relpath( - experiment.registry.get_figure("main_catalog_map"), + experiment.registry.get_figure_key("main_catalog_map"), experiment.registry.run_dir, ), os.path.relpath( - experiment.registry.get_figure("main_catalog_time"), + experiment.registry.get_figure_key("main_catalog_time"), experiment.registry.run_dir, ), ], @@ -81,7 +81,7 @@ def generate_report(experiment, timewindow=-1): # Include results from Experiment for test in experiment.tests: - fig_path = experiment.registry.get_figure(timestr, test) 
+ fig_path = experiment.registry.get_figure_key(timestr, test) width = test.plot_args[0].get("figsize", [4])[0] * 96 report.add_figure( f"{test.name}", @@ -93,7 +93,7 @@ def generate_report(experiment, timewindow=-1): ) for model in experiment.models: try: - fig_path = experiment.registry.get_figure(timestr, f"{test.name}_{model.name}") + fig_path = experiment.registry.get_figure_key(timestr, f"{test.name}_{model.name}") width = test.plot_args[0].get("figsize", [4])[0] * 96 report.add_figure( f"{test.name}: {model.name}", diff --git a/tests/unit/test_plot_handler.py b/tests/unit/test_plot_handler.py index ebabb3f..828b1f3 100644 --- a/tests/unit/test_plot_handler.py +++ b/tests/unit/test_plot_handler.py @@ -63,7 +63,7 @@ def test_plot_catalogs( mock_parse_plot_config.return_value = {"projection": "Mercator"} mock_parse_projection.return_value = MagicMock() - mock_experiment.registry.get_figure.return_value = "cat.png" + mock_experiment.registry.get_figure_key.return_value = "cat.png" plot_handler.plot_catalogs(mock_experiment) diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index 03ba40e..587523e 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -1,7 +1,8 @@ +import os import unittest from datetime import datetime from unittest.mock import patch, MagicMock -from floatcsep.infrastructure.registries import ModelFileRegistry +from floatcsep.infrastructure.registries import ModelFileRegistry, ExperimentFileRegistry class TestModelFileRegistry(unittest.TestCase): @@ -86,9 +87,85 @@ def test_build_tree_time_dependent(self, mock_listdir, mock_makedirs): time_windows=time_windows, model_class="TimeDependentModel", prefix="forecast" ) self.assertIn("2023-01-01_2023-01-02", self.registry_for_folderbased_model.forecasts) - # self.assertTrue(self.registry_for_folderbased_model.inventory["2023-01-01_2023-01-02"]) self.assertIn("2023-01-02_2023-01-03", self.registry_for_folderbased_model.forecasts) - # 
self.assertTrue(self.registry_for_folderbased_model.inventory["2023-01-02_2023-01-03"]) + + +class TestExperimentFileRegistry(unittest.TestCase): + + def setUp(self): + self.registry = ExperimentFileRegistry(workdir="/test/workdir") + + def test_initialization(self): + self.assertEqual(self.registry.workdir, "/test/workdir") + self.assertEqual(self.registry.run_dir, "results") + self.assertEqual(self.registry.results, {}) + self.assertEqual(self.registry.test_catalogs, {}) + self.assertEqual(self.registry.figures, {}) + self.assertEqual(self.registry.model_registries, {}) + + def test_add_and_get_model_registry(self): + model_mock = MagicMock() + model_mock.name = "TestModel" + model_mock.registry = MagicMock(spec=ModelFileRegistry) + + self.registry.add_model_registry(model_mock) + self.assertIn("TestModel", self.registry.model_registries) + self.assertEqual(self.registry.get_model_registry("TestModel"), model_mock.registry) + + @patch("os.makedirs") + def test_build_tree(self, mock_makedirs): + time_windows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] + models = [MagicMock(name="Model1"), MagicMock(name="Model2")] + tests = [MagicMock(name="Test1")] + + self.registry.build_tree(time_windows, models, tests) + + timewindow_str = "2023-01-01_2023-01-02" + self.assertIn(timewindow_str, self.registry.results) + self.assertIn(timewindow_str, self.registry.test_catalogs) + self.assertIn(timewindow_str, self.registry.figures) + + def test_get_test_catalog_key(self): + self.registry.test_catalogs = {"2023-01-01_2023-01-02": "some/path/to/catalog.json"} + result = self.registry.get_test_catalog_key("2023-01-01_2023-01-02") + self.assertTrue(result.endswith("results/some/path/to/catalog.json")) + + def test_get_result_key(self): + self.registry.results = { + "2023-01-01_2023-01-02": { + "Test1": { + "Model1": "some/path/to/result.json" + } + } + } + result = self.registry.get_result_key("2023-01-01_2023-01-02", "Test1", "Model1") + 
self.assertTrue(result.endswith("results/some/path/to/result.json")) + + def test_get_figure_key(self): + self.registry.figures = { + "2023-01-01_2023-01-02": { + "Test1": "some/path/to/figure.png", + "catalog_map": "some/path/to/catalog_map.png", + "catalog_time": "some/path/to/catalog_time.png", + "forecasts": {"Model1": "some/path/to/forecast.png"} + } + } + result = self.registry.get_figure_key("2023-01-01_2023-01-02", "Test1") + self.assertTrue(result.endswith("results/some/path/to/figure.png")) + + @patch("floatcsep.infrastructure.registries.exists") + def test_result_exist(self, mock_exists): + mock_exists.return_value = True + self.registry.results = { + "2023-01-01_2023-01-02": { + "Test1": { + "Model1": "some/path/to/result.json" + } + } + } + result = self.registry.result_exist("2023-01-01_2023-01-02", "Test1", "Model1") + self.assertTrue(result) + mock_exists.assert_called() if __name__ == "__main__": diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py index 709f745..110efff 100644 --- a/tests/unit/test_reporting.py +++ b/tests/unit/test_reporting.py @@ -25,7 +25,7 @@ def test_generate_standard_report(self, mock_markdown_report): # Mock experiment without a custom report function mock_experiment = MagicMock() mock_experiment.postprocess.get.return_value = None - mock_experiment.registry.get_figure.return_value = "figure_path" + mock_experiment.registry.get_figure_key.return_value = "figure_path" mock_experiment.magnitudes = [0, 1] # Call the generate_report function reporting.generate_report(mock_experiment) diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py index c0f567b..c6c7d5e 100644 --- a/tests/unit/test_repositories.py +++ b/tests/unit/test_repositories.py @@ -160,9 +160,10 @@ def test_equal(self, MockModelFileRegistry): class TestResultsRepository(unittest.TestCase): - @patch("floatcsep.infrastructure.repositories.ExperimentRegistry") - def setUp(self, MockRegistry): - self.mock_registry = 
MockRegistry() + @patch("floatcsep.infrastructure.repositories.ExperimentRegistry.factory") + def setUp(self, mock_registry): + self.mock_registry = MagicMock() + self.mock_registry.return_value = mock_registry() self.results_repo = ResultsRepository(self.mock_registry) def test_initialization(self): @@ -191,9 +192,10 @@ def test_write_result(self, mock_open, mock_json_dump): class TestCatalogRepository(unittest.TestCase): - @patch("floatcsep.infrastructure.repositories.ExperimentRegistry") - def setUp(self, MockRegistry): - self.mock_registry = MockRegistry() + @patch("floatcsep.infrastructure.repositories.ExperimentRegistry.factory") + def setUp(self, mock_registry): + self.mock_registry = MagicMock() + self.mock_registry.return_value = mock_registry() self.catalog_repo = CatalogRepository(self.mock_registry) def test_initialization(self): diff --git a/tutorials/case_h/custom_report.py b/tutorials/case_h/custom_report.py index d705d5d..b8ed4b2 100644 --- a/tutorials/case_h/custom_report.py +++ b/tutorials/case_h/custom_report.py @@ -30,8 +30,8 @@ def main(experiment): report.add_figure( f"Input catalog", [ - experiment.registry.get_figure("main_catalog_map"), - experiment.registry.get_figure("main_catalog_time"), + experiment.registry.get_figure_key("main_catalog_map"), + experiment.registry.get_figure_key("main_catalog_time"), ], level=3, ncols=1, @@ -44,7 +44,7 @@ def main(experiment): # Include results from Experiment test = experiment.tests[0] for model in experiment.models: - fig_path = experiment.registry.get_figure(timestr, f"{test.name}_{model.name}") + fig_path = experiment.registry.get_figure_key(timestr, f"{test.name}_{model.name}") report.add_figure( f"{test.name}: {model.name}", fig_path, From 1ee572fb0a0f642d6f38121d9b1b3405be7d0e07 Mon Sep 17 00:00:00 2001 From: pciturri Date: Tue, 13 May 2025 17:17:46 +0200 Subject: [PATCH 5/5] fix: abstracted catalog forecast repository from IO operations --- floatcsep/infrastructure/registries.py | 3 - 
floatcsep/infrastructure/repositories.py | 29 ++- floatcsep/model.py | 4 +- floatcsep/utils/helpers.py | 4 +- floatcsep/utils/readers.py | 166 +++++++++++++++++- .../forecasts/mock_2020-01-01_2020-01-02.csv | 2 +- .../forecasts/mock_2020-01-02_2020-01-03.csv | 2 +- .../integration/test_model_infrastructure.py | 3 +- tests/unit/test_readers.py | 14 +- tests/unit/test_repositories.py | 16 +- 10 files changed, 204 insertions(+), 39 deletions(-) diff --git a/floatcsep/infrastructure/registries.py b/floatcsep/infrastructure/registries.py index 3fcbfd0..f7eaff5 100644 --- a/floatcsep/infrastructure/registries.py +++ b/floatcsep/infrastructure/registries.py @@ -13,7 +13,6 @@ log = logging.getLogger("floatLogger") - class FilepathMixin: """ Small mixin to provide filepath management functionality to Registries that uses files to @@ -147,7 +146,6 @@ def factory(cls, registry_type: str = 'file', **kwargs) -> "ModelRegistry": elif registry_type == 'hdf5': return ModelHDF5Registry(**kwargs) - class ModelFileRegistry(ModelRegistry, FilepathMixin): def __init__( self, @@ -318,7 +316,6 @@ def as_dict(self) -> dict: "forecasts": self.forecasts, } - class ModelHDF5Registry(ModelRegistry): def __init__(self, workdir: str, path: str): diff --git a/floatcsep/infrastructure/repositories.py b/floatcsep/infrastructure/repositories.py index 7f4f685..572d90b 100644 --- a/floatcsep/infrastructure/repositories.py +++ b/floatcsep/infrastructure/repositories.py @@ -12,7 +12,7 @@ from csep.models import EvaluationResult from csep.utils.time_utils import decimal_year -from floatcsep.utils.readers import ForecastParsers +from floatcsep.utils.readers import GriddedForecastParsers, CatalogForecastParsers from floatcsep.infrastructure.registries import ExperimentRegistry, ModelRegistry from floatcsep.utils.helpers import str2timewindow, parse_csep_func from floatcsep.utils.helpers import timewindow2str @@ -102,7 +102,7 @@ def __init__(self, registry: ModelRegistry, **kwargs): self.forecasts = {} def 
load_forecast( - self, tstring: Union[str, list], region=None + self, tstring: Union[str, list], region=None, n_sims=None, ) -> Union[CatalogForecast, list[CatalogForecast]]: """ Returns a forecast object or a sequence of them for a set of time window strings. @@ -110,20 +110,31 @@ def load_forecast( Args: tstring (str, list): String representing the time-window region (optional): A region, in case the forecast requires to be filtered lazily. + n_sims (optional): The number of simulations/synthetic catalogs of the forecast Returns: The CSEP CatalogForecast object or a list of them. """ if isinstance(tstring, str): - return self._load_single_forecast(tstring, region) + return self._load_single_forecast(tstring, region=region, n_sims=n_sims) else: return [self._load_single_forecast(t, region) for t in tstring] - def _load_single_forecast(self, t: str, region=None): - fc_path = self.registry.get_forecast_key(t) - return csep.load_catalog_forecast( - fc_path, region=region, apply_filters=True, filter_spatial=True - ) + def _load_single_forecast(self, tstring: str, region=None, n_sims=None): + start_date, end_date = str2timewindow(tstring) + + fc_path = self.registry.get_forecast_key(tstring) + f_parser = getattr(CatalogForecastParsers, self.registry.fmt) + + forecast_ = f_parser(fc_path, + start_time=start_date, + end_time=end_date, + n_cat=n_sims, + region=region, + apply_filters=True, + filter_spatial=True, + ) + return forecast_ def remove(self, tstring: Union[str, Sequence[str]]): pass @@ -190,7 +201,7 @@ def _load_single_forecast(self, tstring: str, fc_unit: float = 1, name_=""): tstring_ = timewindow2str([start_date, end_date]) f_path = self.registry.get_forecast_key(tstring_) - f_parser = getattr(ForecastParsers, self.registry.fmt) + f_parser = getattr(GriddedForecastParsers, self.registry.fmt) rates, region, mags = f_parser(f_path) diff --git a/floatcsep/model.py b/floatcsep/model.py index d5c48dc..5d3b430 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py
@@ -11,7 +11,7 @@ from floatcsep.utils.accessors import from_zenodo, from_git from floatcsep.infrastructure.environments import EnvironmentFactory -from floatcsep.utils.readers import ForecastParsers, HDF5Serializer +from floatcsep.utils.readers import GriddedForecastParsers, HDF5Serializer from floatcsep.infrastructure.registries import ModelRegistry from floatcsep.infrastructure.repositories import ForecastRepository from floatcsep.utils.helpers import timewindow2str, str2timewindow, parse_nested_dicts @@ -247,7 +247,7 @@ def init_db(self, dbpath: str = "", force: bool = False) -> None: exists """ - parser = getattr(ForecastParsers, self.registry.fmt) + parser = getattr(GriddedForecastParsers, self.registry.fmt) rates, region, mag = parser(self.registry.get_attr("path")) db_func = HDF5Serializer.grid2hdf5 diff --git a/floatcsep/utils/helpers.py b/floatcsep/utils/helpers.py index 3a1eda9..c461b2c 100644 --- a/floatcsep/utils/helpers.py +++ b/floatcsep/utils/helpers.py @@ -75,7 +75,9 @@ def _getattr(obj_, attr_): floatcsep.utils.helpers, floatcsep.utils.accessors, floatcsep.utils.readers.HDF5Serializer, - floatcsep.utils.readers.ForecastParsers, + floatcsep.utils.readers.GriddedForecastParsers, + floatcsep.utils.readers.CatalogForecastParsers, + ] for module in _target_modules: try: diff --git a/floatcsep/utils/readers.py b/floatcsep/utils/readers.py index aa02804..dbe7ea0 100644 --- a/floatcsep/utils/readers.py +++ b/floatcsep/utils/readers.py @@ -1,19 +1,173 @@ import argparse +import csv import logging import os.path import time import xml.etree.ElementTree as eTree +import csep import h5py import numpy import pandas +import pandas as pd +from csep.core.catalogs import CSEPCatalog from csep.core.regions import QuadtreeGrid2D, CartesianGrid2D from csep.models import Polygon +from csep.utils.time_utils import strptime_to_utc_epoch log = logging.getLogger(__name__) +class CatalogForecastParsers: -class ForecastParsers: + @staticmethod + def csv(filename, **kwargs): 
+ csep_headers = ['lon', 'lat', 'magnitude', 'time_string', 'depth', 'catalog_id', + 'event_id'] + hermes_headers = ['realization_id', 'magnitude', 'depth', 'latitude', 'longitude', + 'time'] + headers_df = pd.read_csv(filename, nrows=0).columns.str.strip().to_list() + + # CSEP headers + if headers_df[:2] == csep_headers[:2]: + + return csep.load_catalog_forecast(filename, **kwargs) + + elif headers_df == hermes_headers: + return csep.load_catalog_forecast(filename, + catalog_loader=CatalogForecastParsers.load_hermes_catalog, + **kwargs + ) + else: + raise Exception('Catalog Forecast could not be loaded') + + @staticmethod + def load_hermes_catalog(filename, **kwargs): + """ Loads hermes synthetic catalogs in csep-ascii format. + + This function can load multiple catalogs stored in a single file. This is typically called to + load a catalog-based forecast, but could also load a collection of catalogs stored in the same file + + Args: + filename (str): filepath or directory of catalog files + **kwargs (dict): passed to class constructor + + Return: + yields CSEPCatalog class + """ + + def read_float(val): + """Returns val as float or None if unable""" + try: + val = float(val) + except: + val = None + return val + + def is_header_line(line): + if line[0].lower() == 'realization_id': + return True + else: + return False + + def read_catalog_line(line): + # convert to correct types + + catalog_id = int(line[0]) + magnitude = read_float(line[1]) + depth = read_float(line[2]) + lat = read_float(line[3]) + lon = read_float(line[4]) + # maybe fractional seconds are not included + origin_time = line[5] + if origin_time: + try: + origin_time = strptime_to_utc_epoch(origin_time, + format='%Y-%m-%d %H:%M:%S.%f') + except ValueError: + origin_time = strptime_to_utc_epoch(origin_time, + format='%Y-%m-%d %H:%M:%S') + + event_id = 0 + # temporary event + temp_event = (event_id, origin_time, lat, lon, depth, magnitude) + return temp_event, catalog_id + + # handle all catalogs in
single file + if os.path.isfile(filename): + with open(filename, 'r', newline='') as input_file: + catalog_reader = csv.reader(input_file, delimiter=',') + # csv treats everything as a string convert to correct types + events = [] + # all catalogs should start at zero + prev_id = None + for line in catalog_reader: + # skip header line on first read if included in file + if prev_id is None: + if is_header_line(line): + continue + # read line and return catalog id + temp_event, catalog_id = read_catalog_line(line) + empty = False + # OK if event_id is empty + if all([val in (None, '') for val in temp_event[1:]]): + empty = True + # first event is when prev_id is none, catalog_id should always start at zero + if prev_id is None: + prev_id = 0 + # if the first catalog doesn't start at zero + if catalog_id != prev_id: + if not empty: + events = [temp_event] + else: + events = [] + for id in range(catalog_id): + yield CSEPCatalog(data=[], catalog_id=id, **kwargs) + prev_id = catalog_id + continue + # accumulate event if catalog_id is the same as previous event + if catalog_id == prev_id: + if not all([val in (None, '') for val in temp_event]): + events.append(temp_event) + prev_id = catalog_id + # create and yield class if the events are from different catalogs + elif catalog_id == prev_id + 1: + yield CSEPCatalog(data=events, catalog_id=prev_id, **kwargs) + # add event to new event list + if not empty: + events = [temp_event] + else: + events = [] + prev_id = catalog_id + # this implies there are empty catalogs, because they are not listed in the ascii file + elif catalog_id > prev_id + 1: + yield CSEPCatalog(data=events, catalog_id=prev_id, **kwargs) + # if prev_id = 0 and catalog_id = 2, then we skipped one catalog. 
thus, we skip catalog_id - prev_id - 1 catalogs + num_empty_catalogs = catalog_id - prev_id - 1 + # first yield empty catalog classes + for id in range(num_empty_catalogs): + yield CSEPCatalog(data=[], + catalog_id=catalog_id - num_empty_catalogs + id, + **kwargs) + prev_id = catalog_id + # add event to new event list + if not empty: + events = [temp_event] + else: + events = [] + else: + raise ValueError( + "catalog_id should be monotonically increasing and events should be ordered by catalog_id") + # yield final catalog, note: since this is just loading catalogs, it has no idea how many should be there + cat = CSEPCatalog(data=events, catalog_id=prev_id, **kwargs) + yield cat + + elif os.path.isdir(filename): + raise NotImplementedError( + "reading from directory or batched files not implemented yet!") + + + +class GriddedForecastParsers: @staticmethod def dat(filename): @@ -151,7 +305,7 @@ def is_mag(num): sep = " " if "tile" in line: - rates, region, magnitudes = ForecastParsers.quadtree(filename) + rates, region, magnitudes = GriddedForecastParsers.quadtree(filename) return rates, region, magnitudes data = pandas.read_csv( @@ -308,13 +462,13 @@ def serialize(): args = parser.parse_args() if args.format == "quadtree": - ForecastParsers.quadtree(args.filename) + GriddedForecastParsers.quadtree(args.filename) if args.format == "dat": - ForecastParsers.dat(args.filename) + GriddedForecastParsers.dat(args.filename) if args.format == "csep" or args.format == "csv": - ForecastParsers.csv(args.filename) + GriddedForecastParsers.csv(args.filename) if args.format == "xml": - ForecastParsers.xml(args.filename) + GriddedForecastParsers.xml(args.filename) if __name__ == "__main__": diff --git a/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv b/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv index fdddae6..c05baeb 100644 --- a/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv +++ 
b/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv @@ -1,2 +1,2 @@ -lon, lat, M, time_string, depth, catalog_id, event_id +lon,lat,M,time_string,depth,catalog_id,event_id 1.0,1.0,5.0,2020-01-01T01:01:01.0,10.0,1,1 \ No newline at end of file diff --git a/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv b/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv index f89633b..ad21ed4 100644 --- a/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv +++ b/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv @@ -1,2 +1,2 @@ -lon, lat, M, time_string, depth, catalog_id, event_id +lon,lat,M,time_string,depth,catalog_id,event_id 1.0,1.0,5.0,2020-01-02T01:01:01.0,10.0,1,1 \ No newline at end of file diff --git a/tests/integration/test_model_infrastructure.py b/tests/integration/test_model_infrastructure.py index 7dda39f..81b1bdb 100644 --- a/tests/integration/test_model_infrastructure.py +++ b/tests/integration/test_model_infrastructure.py @@ -39,7 +39,6 @@ def test_time_independent_model_stage(self): [datetime(2023, 1, 1), datetime(2023, 1, 2)], ] self.time_independent_model.stage(time_windows=time_windows) - print("a", self.time_independent_model.registry.as_dict()) self.assertIn("2023-01-01_2023-01-02", self.time_independent_model.registry.forecasts) def test_time_independent_model_get_forecast(self): @@ -123,7 +122,7 @@ def forecast_(_): name = "mock" fname = os.path.join(self._dir, "model.csv") - with patch("floatcsep.readers.ForecastParsers.csv", forecast_): + with patch("floatcsep.readers.GriddedForecastParsers.csv", forecast_): model = self.init_model(name, fname) model.registry.build_tree([[start, end]]) forecast = model.get_forecast(timestring) diff --git a/tests/unit/test_readers.py b/tests/unit/test_readers.py index f4935a1..1f552e3 100644 --- a/tests/unit/test_readers.py +++ b/tests/unit/test_readers.py @@ -25,7 +25,7 @@ def tearDownClass(cls) -> None: def 
test_parse_csv(self): fname = os.path.join(self._dir, "model.csv") numpy.seterr(all="ignore") - rates, region, mags = readers.ForecastParsers.csv(fname) + rates, region, mags = readers.GriddedForecastParsers.csv(fname) rts = numpy.array([[1.0, 0.1], [1.0, 0.1], [1.0, 0.1], [1.0, 0.1]]) orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]]) @@ -39,7 +39,7 @@ def test_parse_csv(self): def test_parse_dat(self): fname = csep.utils.datasets.helmstetter_mainshock_fname - rates, region, mags = readers.ForecastParsers.dat(fname) + rates, region, mags = readers.GriddedForecastParsers.dat(fname) forecast = csep.load_gridded_forecast(fname) self.assertEqual(forecast.region, region) @@ -50,7 +50,7 @@ def test_parse_csv_qtree(self): fname = os.path.join(self._dir, "qtree", "TEAM=N10L11.csv") numpy.seterr(all="ignore") - rates, region, mags = readers.ForecastParsers.csv(fname) + rates, region, mags = readers.GriddedForecastParsers.csv(fname) poly = numpy.array( [[-180.0, 66.51326], [-180.0, 79.171335], [-135.0, 79.171335], [-135.0, 66.51326]] @@ -61,7 +61,7 @@ def test_parse_csv_qtree(self): self.assertEqual(8089, rates.shape[0]) numpy.testing.assert_allclose(poly, region.polygons[2].points) - rates2, region2, mags2 = readers.ForecastParsers.quadtree(fname) + rates2, region2, mags2 = readers.GriddedForecastParsers.quadtree(fname) numpy.testing.assert_allclose(rates, rates2) numpy.testing.assert_allclose( [i.points for i in region.polygons], [i.points for i in region2.polygons] @@ -78,7 +78,7 @@ def test_parse_xml(self): ) numpy.seterr(all="ignore") - rates, region, mags = readers.ForecastParsers.xml(fname) + rates, region, mags = readers.GriddedForecastParsers.xml(fname) orgs = numpy.array([12.6, 38.3]) poly = numpy.array([[12.6, 38.3], [12.6, 38.4], [12.7, 38.4], [12.7, 38.3]]) @@ -94,7 +94,7 @@ def test_parse_xml(self): def test_serialize_hdf5(self): numpy.seterr(all="ignore") fname = os.path.join(self._dir, "model.csv") - rates, region, mags = 
readers.ForecastParsers.csv(fname) + rates, region, mags = readers.GriddedForecastParsers.csv(fname) fname_db = os.path.join(self._dir, "model.hdf5") readers.HDF5Serializer.grid2hdf5(rates, region, mags, hdf5_filename=fname_db) @@ -105,7 +105,7 @@ def test_serialize_hdf5(self): def test_parse_hdf5(self): fname = os.path.join(self._dir, "model_h5.hdf5") - rates, region, mags = readers.ForecastParsers.hdf5(fname) + rates, region, mags = readers.GriddedForecastParsers.hdf5(fname) orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]]) poly_3 = numpy.array([[0.1, 0.1], [0.1, 0.2], [0.2, 0.2], [0.2, 0.1]]) diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py index c6c7d5e..335ed66 100644 --- a/tests/unit/test_repositories.py +++ b/tests/unit/test_repositories.py @@ -4,7 +4,7 @@ from csep.core.forecasts import GriddedForecast -from floatcsep.utils.readers import ForecastParsers +from floatcsep.utils.readers import GriddedForecastParsers from floatcsep.infrastructure.registries import ModelFileRegistry from floatcsep.infrastructure.repositories import ( CatalogForecastRepository, @@ -19,15 +19,17 @@ class TestCatalogForecastRepository(unittest.TestCase): def setUp(self): self.registry = MagicMock(spec=ModelFileRegistry) #todo: Factory registry self.registry.__call__ = MagicMock(return_value="a_duck") + self.registry.fmt = 'csv' @patch("csep.load_catalog_forecast") def test_initialization(self, mock_load_catalog_forecast): repo = CatalogForecastRepository(self.registry, lazy_load=True) self.assertTrue(repo.lazy_load) - @patch("csep.load_catalog_forecast") + @patch("floatcsep.readers.CatalogForecastParsers.csv") def test_load_forecast(self, mock_load_catalog_forecast): repo = CatalogForecastRepository(self.registry) + mock_load_catalog_forecast.return_value = "forecatto" forecast = repo.load_forecast("2023-01-01_2023-01-02") self.assertEqual(forecast, "forecatto") @@ -36,7 +38,7 @@ def test_load_forecast(self, 
mock_load_catalog_forecast): forecasts = repo.load_forecast(["2023-01-01_2023-01-01", "2023-01-02_2023-01-03"]) self.assertEqual(forecasts, ["forecatto", "forecatto"]) - @patch("csep.load_catalog_forecast") + @patch("floatcsep.readers.CatalogForecastParsers.csv") def test_load_single_forecast(self, mock_load_catalog_forecast): # Test _load_single_forecast repo = CatalogForecastRepository(self.registry) @@ -56,7 +58,7 @@ def test_initialization(self): repo = GriddedForecastRepository(self.registry, lazy_load=False) self.assertFalse(repo.lazy_load) - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_load_forecast(self, mock_parser): # Mock parser return values mock_parser.return_value = ("rates", "region", "mags") @@ -77,7 +79,7 @@ def test_load_forecast(self, mock_parser): self.assertEqual(forecasts, ["forecatto", "forecatto"]) self.assertEqual(mock_method.call_count, 2) - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_get_or_load_forecast(self, mock_parser): mock_parser.return_value = ("rates", "region", "mags") repo = GriddedForecastRepository(self.registry, lazy_load=False) @@ -98,7 +100,7 @@ def test_get_or_load_forecast(self, mock_parser): @patch.object(GriddedForecast, "__init__", return_value=None) @patch.object(GriddedForecast, "event_count", new_callable=PropertyMock) @patch.object(GriddedForecast, "scale") - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_load_single_forecast(self, mock_parser, mock_scale, mock_count, mock_init): # Mock parser return values mock_count.return_value = 2 @@ -119,7 +121,7 @@ def test_load_single_forecast(self, mock_parser, mock_scale, mock_count, mock_in end_time=datetime.datetime(2024, 1, 1), ) - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_lazy_load_behavior(self, mock_parser): mock_parser.return_value = ("rates", "region", "mags") # 
Test lazy_load behavior