diff --git a/floatcsep/evaluation.py b/floatcsep/evaluation.py index 93c25ba..fef6f69 100644 --- a/floatcsep/evaluation.py +++ b/floatcsep/evaluation.py @@ -287,7 +287,7 @@ def plot_results( # Regular consistency/comparative test plots (e.g., many models) try: for time_str in timewindow: - fig_path = registry.get_figure(time_str, self.name) + fig_path = registry.get_figure_key(time_str, self.name) results = self.read_results(time_str, models) ax = func(results, plot_args=fargs, **fkwargs) if "code" in fargs: @@ -307,7 +307,7 @@ def plot_results( registry.figures[time_str][fig_name] = os.path.join( time_str, "figures", fig_name ) - fig_path = registry.get_figure(time_str, fig_name) + fig_path = registry.get_figure_key(time_str, fig_name) ax = func(result, plot_args=fargs, **fkwargs, show=False) if "code" in fargs: exec(fargs["code"]) @@ -318,7 +318,7 @@ def plot_results( pyplot.show() elif self.type in ["sequential", "sequential_comparative", "batch"]: - fig_path = registry.get_figure(timewindow[-1], self.name) + fig_path = registry.get_figure_key(timewindow[-1], self.name) results = self.read_results(timewindow[-1], models) ax = func(results, plot_args=fargs, **fkwargs) diff --git a/floatcsep/experiment.py b/floatcsep/experiment.py index da2930a..12be522 100644 --- a/floatcsep/experiment.py +++ b/floatcsep/experiment.py @@ -25,6 +25,7 @@ parse_nested_dicts, ) from floatcsep.infrastructure.engine import Task, TaskGraph +from floatcsep.infrastructure.logger import log_models_tree, log_results_tree log = logging.getLogger("floatLogger") @@ -52,8 +53,8 @@ class Experiment: - growth (:class:`str`): `incremental` or `cumulative` - offset (:class:`float`): recurrence of forecast creation. 
- For further details, see :func:`~floatcsep.utils.timewindows_ti` - and :func:`~floatcsep.utils.timewindows_td` + For further details, see :func:`~floatcsep.utils.time_windows_ti` + and :func:`~floatcsep.utils.time_windows_td` region_config (dict): Contains all the spatial and magnitude specifications. It must contain the following keys: @@ -118,7 +119,7 @@ def __init__( os.makedirs(os.path.join(workdir, rundir), exist_ok=True) self.name = name if name else "floatingExp" - self.registry = ExperimentRegistry(workdir, rundir) + self.registry = ExperimentRegistry.factory(workdir=workdir, run_dir=rundir) self.results_repo = ResultsRepository(self.registry) self.catalog_repo = CatalogRepository(self.registry) @@ -143,7 +144,7 @@ def __init__( log.info(f"Setting up experiment {self.name}:") log.info(f"\tStart: {self.start_date}") log.info(f"\tEnd: {self.end_date}") - log.info(f"\tTime windows: {len(self.timewindows)}") + log.info(f"\tTime windows: {len(self.time_windows)}") log.info(f"\tRegion: {self.region.name if self.region else None}") log.info( f"\tMagnitude range: [{numpy.min(self.magnitudes)}," @@ -175,7 +176,7 @@ def __getattr__(self, item: str) -> object: Override built-in method to return the experiment attributes by also using the command ``experiment.{attr}``. Adds also to the experiment scope the keys of :attr:`region_config` or :attr:`time_config`. These are: ``start_date``, ``end_date``, - ``timewindows``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``, + ``time_windows``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``, `mag_max``, ``mag_bin``, ``depth_min`` depth_max . 
""" @@ -295,8 +296,8 @@ def stage_models(self) -> None: """ log.info("Staging models") for i in self.models: - i.stage(self.timewindows) - self.registry.add_forecast_registry(i) + i.stage(self.time_windows) + self.registry.add_model_registry(i) def set_tests(self, test_config: Union[str, Dict, List]) -> list: """ @@ -376,17 +377,17 @@ def set_tasks(self) -> None: """ # Set the file path structure - self.registry.build_tree(self.timewindows, self.models, self.tests) + self.registry.build_tree(self.time_windows, self.models, self.tests) log.debug("Pre-run forecast summary") - self.registry.log_forecast_trees(self.timewindows) + log_models_tree(log, self.registry, self.time_windows) log.debug("Pre-run result summary") - self.registry.log_results_tree() + log_results_tree(log, self.registry) log.info("Setting up experiment's tasks") # Get the time windows strings - tw_strings = timewindow2str(self.timewindows) + tw_strings = timewindow2str(self.time_windows) # Prepare the testing catalogs task_graph = TaskGraph() @@ -481,7 +482,7 @@ def set_tasks(self) -> None: ) # Set up the Sequential_Comparative Scores elif test_k.type == "sequential_comparative": - tw_strs = timewindow2str(self.timewindows) + tw_strs = timewindow2str(self.time_windows) for model_j in self.models: task_k = Task( instance=test_k, @@ -504,7 +505,7 @@ def set_tasks(self) -> None: ) # Set up the Batch comparative Scores elif test_k.type == "batch": - time_str = timewindow2str(self.timewindows[-1]) + time_str = timewindow2str(self.time_windows[-1]) for model_j in self.models: task_k = Task( instance=test_k, @@ -540,9 +541,9 @@ def run(self) -> None: self.task_graph.run() log.info("Calculation completed") log.debug("Post-run forecast registry") - self.registry.log_forecast_trees(self.timewindows) + log_models_tree(log, self.registry, self.time_windows) log.debug("Post-run result summary") - self.registry.log_results_tree() + log_results_tree(log, self.registry) def read_results(self, test: Evaluation, 
window: str) -> List: """ @@ -559,7 +560,7 @@ def make_repr(self) -> None: """ log.info("Creating reproducibility config file") - repr_config = self.registry.get("repr_config") + repr_config = self.registry.get_attr("repr_config") # Dropping region to results folder if it is a file region_path = self.region_config.get("path", False) @@ -604,7 +605,7 @@ def as_dict(self, extra: Sequence = (), extended=False) -> dict: "time_config": { i: j for i, j in self.time_config.items() - if (i not in ("timewindows",) or extended) + if (i not in ("time_windows",) or extended) }, "region_config": { i: j @@ -731,7 +732,7 @@ def test_stat(test_orig, test_repr): def get_results(self): - win_orig = timewindow2str(self.original.timewindows) + win_orig = timewindow2str(self.original.time_windows) tests_orig = self.original.tests @@ -787,7 +788,7 @@ def get_hash(filename): def get_filecomp(self): - win_orig = timewindow2str(self.original.timewindows) + win_orig = timewindow2str(self.original.time_windows) tests_orig = self.original.tests @@ -801,8 +802,8 @@ def get_filecomp(self): for tw in win_orig: results[test.name][tw] = dict.fromkeys(models_orig) for model in models_orig: - orig_path = self.original.registry.get_result(tw, test, model) - repr_path = self.reproduced.registry.get_result(tw, test, model) + orig_path = self.original.registry.get_result_key(tw, test, model) + repr_path = self.reproduced.registry.get_result_key(tw, test, model) results[test.name][tw][model] = { "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)), @@ -811,8 +812,8 @@ def get_filecomp(self): else: results[test.name] = dict.fromkeys(models_orig) for model in models_orig: - orig_path = self.original.registry.get_result(win_orig[-1], test, model) - repr_path = self.reproduced.registry.get_result(win_orig[-1], test, model) + orig_path = self.original.registry.get_result_key(win_orig[-1], test, model) + repr_path = self.reproduced.registry.get_result_key(win_orig[-1], test, model) 
results[test.name][model] = { "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)), "byte2byte": filecmp.cmp(orig_path, repr_path), diff --git a/floatcsep/infrastructure/logger.py b/floatcsep/infrastructure/logger.py index e96fad9..0e71036 100644 --- a/floatcsep/infrastructure/logger.py +++ b/floatcsep/infrastructure/logger.py @@ -60,3 +60,74 @@ def set_console_log_level(log_level): for handler in logger.handlers: if isinstance(handler, logging.StreamHandler): handler.setLevel(log_level) + + + + +def log_models_tree(log, experiment_registry, time_windows): + """ + Logs the forecasts for all models managed by this ExperimentFileRegistry. + """ + log.debug("===================") + log.debug(f" Total Time Windows: {len(time_windows)}") + for model_name, registry in experiment_registry.model_registries.items(): + log.debug(f" Model: {model_name}") + exists_group = [] + not_exist_group = [] + + for timewindow, filepath in registry.forecasts.items(): + if registry.forecast_exists(timewindow): + exists_group.append(timewindow) + else: + not_exist_group.append(timewindow) + + log.debug(f" Existing forecasts: {len(exists_group)}") + log.debug(f" Missing forecasts: {len(not_exist_group)}") + for timewindow in not_exist_group: + log.debug(f" Time Window: {timewindow}") + log.debug("===================") + + +def log_results_tree(log, experiment_registry): + """ + Logs a summary of the results dictionary, sorted by test. + For each test and time window, it logs whether all models have results, + or if some results are missing, and specifies which models are missing. 
+ """ + log.debug("===================") + + total_results = results_exist_count = results_not_exist_count = 0 + + # Get all unique test names and sort them + all_tests = sorted( + {test_name for tests in experiment_registry.results.values() for test_name in tests} + ) + + for test_name in all_tests: + log.debug(f"Test: {test_name}") + for timewindow, tests in experiment_registry.results.items(): + if test_name in tests: + models = tests[test_name] + missing_models = [] + + for model_name, result_path in models.items(): + total_results += 1 + result_full_path = experiment_registry.get_result_key(timewindow, test_name, model_name) + if os.path.exists(result_full_path): + results_exist_count += 1 + else: + results_not_exist_count += 1 + missing_models.append(model_name) + + if not missing_models: + log.debug(f" Time Window: {timewindow} - All models evaluated.") + else: + log.debug( + f" Time Window: {timewindow} - Missing results for models: " + f"{', '.join(missing_models)}" + ) + + log.debug(f"Total Results: {total_results}") + log.debug(f"Results that Exist: {results_exist_count}") + log.debug(f"Results that Do Not Exist: {results_not_exist_count}") + log.debug("===================") \ No newline at end of file diff --git a/floatcsep/infrastructure/registries.py b/floatcsep/infrastructure/registries.py index 66d2d52..f7eaff5 100644 --- a/floatcsep/infrastructure/registries.py +++ b/floatcsep/infrastructure/registries.py @@ -13,11 +13,12 @@ log = logging.getLogger("floatLogger") - -class FileRegistry(ABC): - - def __init__(self, workdir: str) -> None: - self.workdir = workdir +class FilepathMixin: + """ + Small mixin to provide filepath management functionality to Registries that uses files to + store objects + """ + workdir: str @staticmethod def _parse_arg(arg) -> Union[str, list[str]]: @@ -32,31 +33,60 @@ def _parse_arg(arg) -> Union[str, list[str]]: else: raise Exception("Arg is not found") - @abstractmethod - def as_dict(self) -> dict: - pass + def 
get_attr(self, *args: Sequence[str]) -> str: + """ + Access instance attributes and its contents (e.g., through dict keys) recursively in a + normalized function call. Returns the expected absolute path of this element - @abstractmethod - def build_tree(self, *args, **kwargs) -> None: - pass + Args: + *args: A sequence of keys (usually time-window strings) - @abstractmethod - def get(self, *args: Sequence[str]) -> Any: - pass + Returns: + The registry element (forecast, catalogs, etc.) from a sequence of key value + (usually time-window strings) as filepath + """ + + val = self.__dict__ + for i in args: + parsed_arg = self._parse_arg(i) + val = val[parsed_arg] + return self.abs(val) def abs(self, *paths: Sequence[str]) -> str: + """ + Returns the absolute path of an object, relative to the Registry workdir. + + Args: + *paths: + + Returns: + + """ _path = normpath(abspath(join(self.workdir, *paths))) return _path def abs_dir(self, *paths: Sequence[str]) -> str: + """ + Returns the absolute path of the directory containing an item relative to the Registry + workdir. + Args: + *paths: sequence of keys (usually time-window strings) + + Returns: + String describing the absolute directory + """ _path = normpath(abspath(join(self.workdir, *paths))) _dir = dirname(_path) return _dir def rel(self, *paths: Sequence[str]) -> str: - """Gets the relative path of a file, when it was defined relative to. + """ + Gets the relative path of an item, relative to the Registry workdir - the experiment working dir. + Args: + *paths: sequence of keys (usually time-window strings) + Returns: + String describing the relative path """ _abspath = normpath(abspath(join(self.workdir, *paths))) @@ -64,9 +94,14 @@ def rel(self, *paths: Sequence[str]) -> str: return _relpath def rel_dir(self, *paths: Sequence[str]) -> str: - """Gets the absolute path of a file, when it was defined relative to. 
+ """ + Gets the relative path of the directory containing an item, relative to the Registry + workdir - the experiment working dir. + Args: + *paths: sequence of keys (usually time-window strings) + Returns: + String describing the relative path """ _path = normpath(abspath(join(self.workdir, *paths))) @@ -75,17 +110,43 @@ def rel_dir(self, *paths: Sequence[str]) -> str: return relpath(_dir, self.workdir) def file_exists(self, *args: Sequence[str]): - file_abspath = self.get(*args) + """ + Determine if such a file exists in the filesystem + + Args: + *args: sequence of keys (usually time-window strings) + Returns: + flag indicating if file exists + """ + file_abspath = self.get_attr(*args) return exists(file_abspath) +class ModelRegistry(ABC): + @abstractmethod + def get_input_catalog_key(self, tstring: str) -> str: + pass -class ForecastRegistry(FileRegistry): - """ - The class has the responsibility of managing the keys (based on timewindow strings) and path - structure of the forecast pertaining to a model (i.e., forecasts from different - time-windows), keeping track of the forecast existence and path in the filesystem. - """ + @abstractmethod + def get_forecast_key(self, tstring: str) -> str: + pass + + @abstractmethod + def get_args_key(self, tstring: str) -> str: + pass + + @classmethod + def factory(cls, registry_type: str = 'file', **kwargs) -> "ModelRegistry": + """Factory method. Instantiate first on any explicit option provided in the model + configuration. + """ + + if registry_type == 'file': + return ModelFileRegistry(**kwargs) + elif registry_type == 'hdf5': + return ModelHDF5Registry(**kwargs) + +class ModelFileRegistry(ModelRegistry, FilepathMixin): def __init__( self, workdir: str, @@ -104,55 +165,25 @@ def __init__( args_file (str): The path of the arguments file (only for TimeDependentModel). input_cat (str): : The path of the arguments file (only for TimeDependentModel). 
""" - super().__init__(workdir) + self.workdir = workdir self.path = path self.database = database self.args_file = args_file self.input_cat = input_cat self.forecasts = {} - self._fmt = fmt - def get(self, *args: Sequence[str]) -> str: - """ - Args: - *args: A sequence of keys (usually time-window strings) - - Returns: - The registry element (forecast, catalogs, etc.) from a sequence of key value - (usually time-window strings) - """ - - val = self.__dict__ - for i in args: - parsed_arg = self._parse_arg(i) - val = val[parsed_arg] - return self.abs(val) - - def get_forecast(self, *args: Sequence[str]) -> str: - """ - Gets the filepath of a forecast for a given sequence of keys (usually a timewindow - string). - - Args: - *args: A sequence of keys (usually time-window strings) - - Returns: - The forecast registry from a sequence of key values - """ - return self.get("forecasts", *args) - @property def dir(self) -> str: """ Returns: The directory containing the model source. """ - if os.path.isdir(self.get("path")): - return self.get("path") + if os.path.isdir(self.get_attr("path")): + return self.get_attr("path") else: - return os.path.dirname(self.get("path")) + return os.path.dirname(self.get_attr("path")) @property def fmt(self) -> str: @@ -170,24 +201,9 @@ def fmt(self) -> str: else: return self._fmt - def as_dict(self) -> dict: - """ - - Returns: - Simple dictionary serialization of the instance with the core attributes - """ - return { - "workdir": self.workdir, - "path": self.path, - "database": self.database, - "args_file": self.args_file, - "input_cat": self.input_cat, - "forecasts": self.forecasts, - } - def forecast_exists(self, timewindow: Union[str, list]) -> Union[bool, Sequence[bool]]: """ - Checks if forecasts exist for a sequence of timewindows + Checks if forecasts exist for a sequence of time_windows Args: timewindow (str, list): A single or sequence of strings representing a time window @@ -200,9 +216,48 @@ def forecast_exists(self, timewindow: 
Union[str, list]) -> Union[bool, Sequence[ else: return [self.file_exists("forecasts", i) for i in timewindow] + def get_input_catalog_key(self, *args: Sequence[str]) -> str: + """ + Gets the filepath of the input catalog for a given sequence of keys (usually a timewindow + string). + + Args: + *args: A sequence of keys (usually time-window strings) + + Returns: + The input catalog registry key from a sequence of key values + """ + return self.get_attr("input_cat", *args) + + def get_forecast_key(self, *args: Sequence[str]) -> str: + """ + Gets the filepath of a forecast for a given sequence of keys (usually a timewindow + string). + + Args: + *args: A sequence of keys (usually time-window strings) + + Returns: + The forecast registry from a sequence of key values + """ + return self.get_attr("forecasts", *args) + + def get_args_key(self, *args: Sequence[str]) -> str: + """ + Gets the filepath of an arguments file for a given sequence of keys (usually a timewindow + string). + + Args: + *args: A sequence of keys (usually time-window strings) + + Returns: + The argument file's key(s) from a sequence of key values + """ + return self.get_attr("args_file", *args) + def build_tree( self, - timewindows: Sequence[Sequence[datetime]] = None, + time_windows: Sequence[Sequence[datetime]] = None, model_class: str = "TimeIndependentModel", prefix: str = None, args_file: str = None, @@ -212,7 +267,7 @@ def build_tree( Creates the run directory, and reads the file structure inside. Args: - timewindows (list(str)): List of time windows or strings. + time_windows (list(str)): List of time windows or strings. 
model_class (str): Model's class name prefix (str): prefix of the model forecast filenames if TD args_file (str, bool): input arguments path of the model if TD @@ -221,7 +276,7 @@ def build_tree( """ - windows = timewindow2str(timewindows) + windows = timewindow2str(time_windows) if model_class == "TimeIndependentModel": fname = self.database if self.database else self.path @@ -246,27 +301,72 @@ def build_tree( win: join(dirtree["forecasts"], f"{prefix}_{win}.{self.fmt}") for win in windows } - def log_tree(self) -> None: + def as_dict(self) -> dict: """ - Logs a grouped summary of the forecasts' dictionary. - Groups time windows by whether the forecast exists or not. + + Returns: + Simple dictionary serialization of the instance with the core attributes """ - exists_group = [] - not_exist_group = [] + return { + "workdir": self.workdir, + "path": self.path, + "database": self.database, + "args_file": self.args_file, + "input_cat": self.input_cat, + "forecasts": self.forecasts, + } - for timewindow, filepath in self.forecasts.items(): - if self.forecast_exists(timewindow): - exists_group.append(timewindow) - else: - not_exist_group.append(timewindow) +class ModelHDF5Registry(ModelRegistry): - log.debug(f" Existing forecasts: {len(exists_group)}") - log.debug(f" Missing forecasts: {len(not_exist_group)}") - for timewindow in not_exist_group: - log.debug(f" Time Window: {timewindow}") + def __init__(self, workdir: str, path: str): + pass + def get_input_catalog_key(self, tstring: str) -> str: + return '' + def get_forecast_key(self, tstring: str) -> str: + return '' + def get_args_key(self, tstring: str) -> str: + return '' + +class ExperimentRegistry(ABC): + @abstractmethod + def add_model_registry(self, model: "Model") -> None: + pass + @abstractmethod + def get_model_registry(self, model_name: str) -> ModelRegistry: + pass + + @abstractmethod + def get_result_key(self, test_name: str, model_name: str, tstring: str) -> str: + pass -class 
ExperimentRegistry(FileRegistry): + @abstractmethod + def get_figure_key(self, test_name: str, model_name: str, tstring: str) -> str: + pass + + @abstractmethod + def get_test_catalog_key(self, tstring: str) -> str: + pass + + @abstractmethod + def build_tree( + self, + time_windows: Sequence[Sequence[datetime]], + models: Sequence["Model"], + tests: Sequence["Evaluation"], + ) -> None: + pass + + @classmethod + def factory(cls, registry_type: str = 'file', **kwargs) -> "ExperimentRegistry": + """Factory method. Instantiate first on any explicit option provided in the experiment + configuration. + """ + + if registry_type == 'file': + return ExperimentFileRegistry(**kwargs) + +class ExperimentFileRegistry(ExperimentRegistry, FilepathMixin): """ The class has the responsibility of managing the keys (based on models, timewindow and evaluation name strings) to the structure of the experiment inputs (catalogs, models etc) @@ -281,96 +381,97 @@ def __init__(self, workdir: str, run_dir: str = "results") -> None: workdir: The working directory for the experiment run-time. run_dir: The directory in which the results will be stored. """ - super().__init__(workdir) + self.workdir = workdir self.run_dir = run_dir self.results = {} self.test_catalogs = {} self.figures = {} self.repr_config = "repr_config.yml" - self.forecast_registries = {} + self.model_registries = {} + + def get_attr(self, *args: Any) -> str: + """ + Args: + *args: A sequence of keys (usually models, tests and/or time-window strings) + + Returns: + The filepath from a sequence of key values (usually models first, then time-window + strings) + """ + val = self.__dict__ + for i in args: + parsed_arg = self._parse_arg(i) + val = val[parsed_arg] + return self.abs(self.run_dir, val) - def add_forecast_registry(self, model: "Model") -> None: + def add_model_registry(self, model: "Model") -> None: """ - Adds a model's ForecastRegistry to the ExperimentRegistry. 
+ Adds a model's ModelRegistry to the ExperimentFileRegistry. Args: model (str): A Model object """ - self.forecast_registries[model.name] = model.registry + self.model_registries[model.name] = model.registry - def get_forecast_registry(self, model_name: str) -> None: + def get_model_registry(self, model_name: str) -> None: """ - Retrieves a model's ForecastRegistry from the ExperimentRegistry. + Retrieves a model's ModelRegistry from the ExperimentFileRegistry. Args: model_name (str): The name of the model. Returns: - ForecastRegistry: The ForecastRegistry associated with the model. + ModelRegistry: The ModelRegistry associated with the model. """ - return self.forecast_registries.get(model_name) + return self.model_registries.get(model_name) - def log_forecast_trees(self, timewindows: list) -> None: - """ - Logs the forecasts for all models managed by this ExperimentRegistry. + def result_exist(self, timewindow_str: str, test_name: str, model_name: str) -> bool: """ - log.debug("===================") - log.debug(f" Total Time Windows: {len(timewindows)}") - for model_name, registry in self.forecast_registries.items(): - log.debug(f" Model: {model_name}") - registry.log_tree() - log.debug("===================") + Checks if a given test results exist - def get(self, *args: Any) -> str: - """ Args: - *args: A sequence of keys (usually models, tests and/or time-window strings) + timewindow_str (str): String representing the time window + test_name (str): Name of the evaluation + model_name (str): Name of the model - Returns: - The filepath from a sequence of key values (usually models first, then time-window - strings) """ - val = self.__dict__ - for i in args: - parsed_arg = self._parse_arg(i) - val = val[parsed_arg] - return self.abs(self.run_dir, val) + return self.file_exists("results", timewindow_str, test_name, model_name) - def get_result(self, *args: Sequence[any]) -> str: + def get_test_catalog_key(self, *args: Sequence[any]) -> str: """ - Gets the 
file path of an evaluation result. + Gets the file path of a testing catalog. Args: - args: A sequence of keys (usually models, tests and/or time-window strings) + *args: A sequence of keys (time-window strings) Returns: - The filepath of a serialized result + The filepath of the testing catalog for a given time-window """ - val = self.results + val = self.test_catalogs for i in args: parsed_arg = self._parse_arg(i) val = val[parsed_arg] return self.abs(self.run_dir, val) - def get_test_catalog(self, *args: Sequence[any]) -> str: + def get_result_key(self, *args: Sequence[any]) -> str: """ - Gets the file path of a testing catalog. + Gets the file path of an evaluation result. Args: - *args: A sequence of keys (time-window strings) + args: A sequence of keys (usually models, tests and/or time-window strings) Returns: - The filepath of the testing catalog for a given time-window + The filepath of a serialized result """ - val = self.test_catalogs + val = self.results for i in args: parsed_arg = self._parse_arg(i) val = val[parsed_arg] return self.abs(self.run_dir, val) - def get_figure(self, *args: Sequence[any]) -> str: + def get_figure_key(self, *args: Sequence[any]) -> str: """ Gets the file path of a result figure. 
@@ -386,25 +487,9 @@ def get_figure(self, *args: Sequence[any]) -> str: val = val[parsed_arg] return self.abs(self.run_dir, val) - def result_exist(self, timewindow_str: str, test_name: str, model_name: str) -> bool: - """ - Checks if a given test results exist - - Args: - timewindow_str (str): String representing the time window - test_name (str): Name of the evaluation - model_name (str): Name of the model - - """ - return self.file_exists("results", timewindow_str, test_name, model_name) - - def as_dict(self) -> str: - # todo: rework - return self.workdir - def build_tree( self, - timewindows: Sequence[Sequence[datetime]], + time_windows: Sequence[Sequence[datetime]], models: Sequence["Model"], tests: Sequence["Evaluation"], ) -> None: @@ -412,12 +497,12 @@ def build_tree( Creates the run directory and reads the file structure inside. Args: - timewindows: List of time windows, or representing string. + time_windows: List of time windows, or representing string. models: List of models or model names tests: List of tests or test names """ - windows = timewindow2str(timewindows) + windows = timewindow2str(time_windows) models = [i.name for i in models] tests = [i.name for i in tests] @@ -465,46 +550,7 @@ def build_tree( self.test_catalogs = test_catalogs self.figures = figures - def log_results_tree(self): - """ - Logs a summary of the results dictionary, sorted by test. - For each test and time window, it logs whether all models have results, - or if some results are missing, and specifies which models are missing. 
- """ - log.debug("===================") - - total_results = results_exist_count = results_not_exist_count = 0 - - # Get all unique test names and sort them - all_tests = sorted( - {test_name for tests in self.results.values() for test_name in tests} - ) - - for test_name in all_tests: - log.debug(f"Test: {test_name}") - for timewindow, tests in self.results.items(): - if test_name in tests: - models = tests[test_name] - missing_models = [] - - for model_name, result_path in models.items(): - total_results += 1 - result_full_path = self.get_result(timewindow, test_name, model_name) - if os.path.exists(result_full_path): - results_exist_count += 1 - else: - results_not_exist_count += 1 - missing_models.append(model_name) - - if not missing_models: - log.debug(f" Time Window: {timewindow} - All models evaluated.") - else: - log.debug( - f" Time Window: {timewindow} - Missing results for models: " - f"{', '.join(missing_models)}" - ) - - log.debug(f"Total Results: {total_results}") - log.debug(f"Results that Exist: {results_exist_count}") - log.debug(f"Results that Do Not Exist: {results_not_exist_count}") - log.debug("===================") + def as_dict(self) -> str: + + return self.workdir + diff --git a/floatcsep/infrastructure/repositories.py b/floatcsep/infrastructure/repositories.py index c485d26..572d90b 100644 --- a/floatcsep/infrastructure/repositories.py +++ b/floatcsep/infrastructure/repositories.py @@ -12,8 +12,8 @@ from csep.models import EvaluationResult from csep.utils.time_utils import decimal_year -from floatcsep.utils.readers import ForecastParsers -from floatcsep.infrastructure.registries import ForecastRegistry, ExperimentRegistry +from floatcsep.utils.readers import GriddedForecastParsers, CatalogForecastParsers +from floatcsep.infrastructure.registries import ExperimentRegistry, ModelRegistry from floatcsep.utils.helpers import str2timewindow, parse_csep_func from floatcsep.utils.helpers import timewindow2str @@ -27,7 +27,7 @@ class 
ForecastRepository(ABC): @abstractmethod - def __init__(self, registry: ForecastRegistry): + def __init__(self, registry: ModelRegistry): self.registry = registry self.lazy_load = False self.forecasts = {} @@ -61,7 +61,7 @@ def __eq__(self, other) -> bool: @classmethod def factory( - cls, registry: ForecastRegistry, model_class: str, forecast_type: str = None, **kwargs + cls, registry: ModelRegistry, model_class: str, forecast_type: str = None, **kwargs ) -> "ForecastRepository": """Factory method. Instantiate first on explicit option provided in the model configuration. Then, defaults to gridded forecast for TimeIndependentModel and catalog @@ -89,11 +89,11 @@ class CatalogForecastRepository(ForecastRepository): """ - def __init__(self, registry: ForecastRegistry, **kwargs): + def __init__(self, registry: ModelRegistry, **kwargs): """ Args: - registry (ForecastRegistry): The registry containing the keys/path to the forecasts + registry (ModelRegistry): The registry containing the keys/path to the forecasts given their time-windows. **kwargs: """ @@ -102,7 +102,7 @@ def __init__(self, registry: ForecastRegistry, **kwargs): self.forecasts = {} def load_forecast( - self, tstring: Union[str, list], region=None + self, tstring: Union[str, list], region=None, n_sims=None, ) -> Union[CatalogForecast, list[CatalogForecast]]: """ Returns a forecast object or a sequence of them for a set of time window strings. @@ -110,20 +110,31 @@ def load_forecast( Args: tstring (str, list): String representing the time-window region (optional): A region, in case the forecast requires to be filtered lazily. + n_sims (optional): The number of simulations/synthetic catalogs of the forecast Returns: The CSEP CatalogForecast object or a list of them. 
""" if isinstance(tstring, str): - return self._load_single_forecast(tstring, region) + return self._load_single_forecast(tstring, region=region, n_sims=n_sims) else: return [self._load_single_forecast(t, region) for t in tstring] - def _load_single_forecast(self, t: str, region=None): - fc_path = self.registry.get_forecast(t) - return csep.load_catalog_forecast( - fc_path, region=region, apply_filters=True, filter_spatial=True - ) + def _load_single_forecast(self, tstring: str, region=None, n_sims=None): + start_date, end_date = str2timewindow(tstring) + + fc_path = self.registry.get_forecast_key(tstring) + f_parser = getattr(CatalogForecastParsers, self.registry.fmt) + + forecast_ = f_parser(fc_path, + start_time=start_date, + end_time=end_date, + n_cat=n_sims, + region=region, + apply_filters=True, + filter_spatial=True, + ) + return forecast_ def remove(self, tstring: Union[str, Sequence[str]]): pass @@ -136,11 +147,11 @@ class GriddedForecastRepository(ForecastRepository): avoid parsing files repeatedly (Skip for large files). """ - def __init__(self, registry: ForecastRegistry, **kwargs): + def __init__(self, registry: ModelRegistry, **kwargs): """ Args: - registry (ForecastRegistry): The registry containing the keys/path to the forecasts + registry (ModelRegistry): The registry containing the keys/path to the forecasts given their time-windows. 
**kwargs: """ @@ -189,8 +200,8 @@ def _load_single_forecast(self, tstring: str, fc_unit: float = 1, name_=""): time_horizon = decimal_year(end_date) - decimal_year(start_date) tstring_ = timewindow2str([start_date, end_date]) - f_path = self.registry.get_forecast(tstring_) - f_parser = getattr(ForecastParsers, self.registry.fmt) + f_path = self.registry.get_forecast_key(tstring_) + f_parser = getattr(GriddedForecastParsers, self.registry.fmt) rates, region, mags = f_parser(f_path) @@ -243,7 +254,7 @@ def _load_result( else: wstr_ = window - eval_path = self.registry.get_result(wstr_, test, model) + eval_path = self.registry.get_result_key(wstr_, test, model) with open(eval_path, "r") as file_: model_eval = EvaluationResult.from_dict(json.load(file_)) @@ -287,7 +298,7 @@ def write_result(self, result: EvaluationResult, test, model, window) -> None: window: Name of the time-window """ - path = self.registry.get_result(window, test, model) + path = self.registry.get_result_key(window, test, model) class NumpyEncoder(json.JSONEncoder): def default(self, obj): @@ -381,8 +392,8 @@ def catalog(self) -> CSEPCatalog: if isfile(self.cat_path): return CSEPCatalog.load_json(self.cat_path) bounds = { - "start_time": min([item for sublist in self.timewindows for item in sublist]), - "end_time": max([item for sublist in self.timewindows for item in sublist]), + "start_time": min([item for sublist in self.time_windows for item in sublist]), + "end_time": max([item for sublist in self.time_windows for item in sublist]), "min_magnitude": self.magnitudes.min(), "max_depth": self.depths.max(), } @@ -471,7 +482,7 @@ def set_test_cat(self, tstring: str) -> None: tstring (str): Time window string """ - testcat_name = self.registry.get_test_catalog(tstring) + testcat_name = self.registry.get_test_catalog_key(tstring) if not exists(testcat_name): log.debug( f"Filtering testing catalog and saving to {self.registry.rel(testcat_name)}" @@ -504,4 +515,4 @@ def set_input_cat(self, tstring: str, 
model: "Model") -> None: """ start, end = str2timewindow(tstring) sub_cat = self.catalog.filter([f"origin_time < {start.timestamp() * 1000}"]) - sub_cat.write_ascii(filename=model.registry.get("input_cat")) + sub_cat.write_ascii(filename=model.registry.get_input_catalog_key()) diff --git a/floatcsep/model.py b/floatcsep/model.py index 0993a55..5d3b430 100644 --- a/floatcsep/model.py +++ b/floatcsep/model.py @@ -11,8 +11,8 @@ from floatcsep.utils.accessors import from_zenodo, from_git from floatcsep.infrastructure.environments import EnvironmentFactory -from floatcsep.utils.readers import ForecastParsers, HDF5Serializer -from floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.utils.readers import GriddedForecastParsers, HDF5Serializer +from floatcsep.infrastructure.registries import ModelRegistry from floatcsep.infrastructure.repositories import ForecastRepository from floatcsep.utils.helpers import timewindow2str, str2timewindow, parse_nested_dicts @@ -60,7 +60,7 @@ def __init__( self.__dict__.update(**kwargs) @abstractmethod - def stage(self, timewindows=None) -> None: + def stage(self, time_windows=None) -> None: """Prepares the stage for a model run.""" pass @@ -116,7 +116,7 @@ def get_source(self, zenodo_id: int = None, giturl: str = None, **kwargs) -> Non raise FileNotFoundError("Model has no path or identified") if not os.path.exists(self.registry.dir) or not os.path.exists( - self.registry.get("path") + self.registry.get_attr("path") ): raise FileNotFoundError( f"Directory '{self.registry.dir}' or file {self.registry}' do not exist. 
" @@ -210,18 +210,19 @@ def __init__(self, name: str, model_path: str, forecast_unit=1, store_db=False, self.forecast_unit = forecast_unit self.store_db = store_db - self.registry = ForecastRegistry(kwargs.get("workdir", os.getcwd()), model_path) + self.registry = ModelRegistry.factory(workdir=kwargs.get("workdir", os.getcwd()), + path=model_path) self.repository = ForecastRepository.factory( self.registry, model_class=self.__class__.__name__, **kwargs ) - def stage(self, timewindows: Sequence[Sequence[datetime]] = None) -> None: + def stage(self, time_windows: Sequence[Sequence[datetime]] = None) -> None: """ Acquire the forecast data if it is not in the file system. Sets the paths internally (or database pointers) to the forecast data. Args: - timewindows (list): time_windows that the forecast data represents. + time_windows (list): time_windows that the forecast data represents. """ if self.force_stage or not self.registry.file_exists("path"): @@ -231,7 +232,7 @@ def stage(self, timewindows: Sequence[Sequence[datetime]] = None) -> None: if self.store_db: self.init_db() - self.registry.build_tree(timewindows=timewindows, model_class=self.__class__.__name__) + self.registry.build_tree(time_windows=time_windows, model_class=self.__class__.__name__) def init_db(self, dbpath: str = "", force: bool = False) -> None: """ @@ -246,8 +247,8 @@ def init_db(self, dbpath: str = "", force: bool = False) -> None: exists """ - parser = getattr(ForecastParsers, self.registry.fmt) - rates, region, mag = parser(self.registry.get("path")) + parser = getattr(GriddedForecastParsers, self.registry.fmt) + rates, region, mag = parser(self.registry.get_attr("path")) db_func = HDF5Serializer.grid2hdf5 if not dbpath: @@ -320,9 +321,9 @@ def __init__( self.func = func self.func_kwargs = func_kwargs or {} - self.registry = ForecastRegistry(workdir=kwargs.get("workdir", os.getcwd()), - path=model_path, - fmt=fmt) + self.registry = ModelRegistry.factory(workdir=kwargs.get("workdir", 
os.getcwd()), + path=model_path, + fmt=fmt) self.repository = ForecastRepository.factory( self.registry, model_class=self.__class__.__name__, **kwargs ) @@ -333,7 +334,7 @@ def __init__( self.build, self.name, self.registry.abs(model_path) ) - def stage(self, timewindows=None) -> None: + def stage(self, time_windows=None) -> None: """ Core method to interface a model with the experiment. @@ -351,7 +352,7 @@ def stage(self, timewindows=None) -> None: self.environment.create_environment(force=self.force_build) self.registry.build_tree( - timewindows=timewindows, + time_windows=time_windows, model_class=self.__class__.__name__, prefix=self.__dict__.get("prefix", self.name), args_file=self.__dict__.get("args_file", None), @@ -366,7 +367,7 @@ def get_forecast( Note: The argument ``tstring`` is formatted according to how the Experiment - handles timewindows, specified in the functions + handles time_windows, specified in the functions :func:`~floatcsep.utils.helpers.timewindow2str` and :func:`~floatcsep.utils.helpers.str2timewindow` @@ -385,7 +386,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None: Note: The argument ``tstring`` is formatted according to how the Experiment - handles timewindows, specified in the functions + handles time_windows, specified in the functions :func:`~floatcsep.utils.helpers.timewindow2str` and :func:`~floatcsep.utils.helpers.str2timewindow` @@ -406,7 +407,7 @@ def create_forecast(self, tstring: str, **kwargs) -> None: f"Running {self.name} using {self.environment.__class__.__name__}:" f" {timewindow2str([start_date, end_date])}" ) - self.environment.run_command(f"{self.func} {self.registry.get('args_file')}") + self.environment.run_command(f"{self.func} {self.registry.get_args_key()}") def prepare_args(self, start: datetime, end: datetime, **kwargs) -> None: """ @@ -421,7 +422,7 @@ def prepare_args(self, start: datetime, end: datetime, **kwargs) -> None: **kwargs: represents additional model arguments (name/value pair) """ - 
filepath = self.registry.get("args_file") + filepath = self.registry.get_args_key() fmt = os.path.splitext(filepath)[1] if fmt == ".txt": diff --git a/floatcsep/postprocess/plot_handler.py b/floatcsep/postprocess/plot_handler.py index 18399d6..2de233f 100644 --- a/floatcsep/postprocess/plot_handler.py +++ b/floatcsep/postprocess/plot_handler.py @@ -27,10 +27,10 @@ def plot_results(experiment: "Experiment") -> None: """ log.info("Plotting evaluation results") - timewindows = timewindow2str(experiment.timewindows) + time_windows = timewindow2str(experiment.time_windows) for test in experiment.tests: - test.plot_results(timewindows, experiment.models, experiment.registry) + test.plot_results(time_windows, experiment.models, experiment.registry) def plot_forecasts(experiment: "Experiment") -> None: @@ -76,9 +76,9 @@ def plot_forecasts(experiment: "Experiment") -> None: # Get the time windows to be plotted. Defaults to only the last time window. time_windows = ( - timewindow2str(experiment.timewindows) + timewindow2str(experiment.time_windows) if plot_forecast_config.get("all_time_windows") - else [timewindow2str(experiment.timewindows[-1])] + else [timewindow2str(experiment.time_windows[-1])] ) # Get the projection of the plots @@ -106,7 +106,7 @@ def plot_forecasts(experiment: "Experiment") -> None: } ), ) - fig_path = experiment.registry.get_figure(window, "forecasts", model.name) + fig_path = experiment.registry.get_figure_key(window, "forecasts", model.name) pyplot.savefig(fig_path, dpi=plot_forecast_config.get("dpi", 300)) @@ -167,17 +167,17 @@ def plot_catalogs(experiment: "Experiment") -> None: # Plot catalog map ax = main_catalog.plot(plot_args=plot_catalog_config) - cat_map_path = experiment.registry.get_figure("main_catalog_map") + cat_map_path = experiment.registry.get_figure_key("main_catalog_map") ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300)) # Plot catalog time series vs. 
magnitude ax = magnitude_vs_time(main_catalog) - cat_time_path = experiment.registry.get_figure("main_catalog_time") + cat_time_path = experiment.registry.get_figure_key("main_catalog_time") ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300)) # If selected, plot the test catalogs for each of the time windows if plot_catalog_config.get("all_time_windows"): - for tw in experiment.timewindows: + for tw in experiment.time_windows: test_catalog = experiment.catalog_repo.get_test_cat(timewindow2str(tw)) if test_catalog.get_number_of_events() != 0: @@ -185,11 +185,11 @@ def plot_catalogs(experiment: "Experiment") -> None: continue ax = test_catalog.plot(plot_args=plot_catalog_config) - cat_map_path = experiment.registry.get_figure(tw, "catalog_map") + cat_map_path = experiment.registry.get_figure_key(tw, "catalog_map") ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300)) ax = magnitude_vs_time(test_catalog) - cat_time_path = experiment.registry.get_figure(tw, "catalog_time") + cat_time_path = experiment.registry.get_figure_key(tw, "catalog_time") ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300)) diff --git a/floatcsep/postprocess/reporting.py b/floatcsep/postprocess/reporting.py index 7ec0e5d..cacae0c 100644 --- a/floatcsep/postprocess/reporting.py +++ b/floatcsep/postprocess/reporting.py @@ -35,7 +35,7 @@ def generate_report(experiment, timewindow=-1): custom_report(report_function, experiment) return - timewindow = experiment.timewindows[timewindow] + timewindow = experiment.time_windows[timewindow] timestr = timewindow2str(timewindow) log.info(f"Saving report into {experiment.registry.run_dir}") @@ -60,11 +60,11 @@ def generate_report(experiment, timewindow=-1): "Input catalog", [ os.path.relpath( - experiment.registry.get_figure("main_catalog_map"), + experiment.registry.get_figure_key("main_catalog_map"), experiment.registry.run_dir, ), os.path.relpath( - 
experiment.registry.get_figure("main_catalog_time"), + experiment.registry.get_figure_key("main_catalog_time"), experiment.registry.run_dir, ), ], @@ -81,7 +81,7 @@ def generate_report(experiment, timewindow=-1): # Include results from Experiment for test in experiment.tests: - fig_path = experiment.registry.get_figure(timestr, test) + fig_path = experiment.registry.get_figure_key(timestr, test) width = test.plot_args[0].get("figsize", [4])[0] * 96 report.add_figure( f"{test.name}", @@ -93,7 +93,7 @@ def generate_report(experiment, timewindow=-1): ) for model in experiment.models: try: - fig_path = experiment.registry.get_figure(timestr, f"{test.name}_{model.name}") + fig_path = experiment.registry.get_figure_key(timestr, f"{test.name}_{model.name}") width = test.plot_args[0].get("figsize", [4])[0] * 96 report.add_figure( f"{test.name}: {model.name}", diff --git a/floatcsep/utils/helpers.py b/floatcsep/utils/helpers.py index 868e865..c461b2c 100644 --- a/floatcsep/utils/helpers.py +++ b/floatcsep/utils/helpers.py @@ -75,7 +75,9 @@ def _getattr(obj_, attr_): floatcsep.utils.helpers, floatcsep.utils.accessors, floatcsep.utils.readers.HDF5Serializer, - floatcsep.utils.readers.ForecastParsers, + floatcsep.utils.readers.GriddedForecastParsers, + floatcsep.utils.readers.CatalogForecastParsers, + ] for module in _target_modules: try: @@ -158,14 +160,14 @@ def read_time_cfg(time_config, **kwargs): if "offset" in time_config.keys(): time_config["offset"] = parse_timedelta_string(time_config["offset"]) - if not time_config.get("timewindows"): + if not time_config.get("time_windows"): if experiment_class == "ti": - time_config["timewindows"] = timewindows_ti(**time_config) + time_config["time_windows"] = time_windows_ti(**time_config) elif experiment_class == "td": - time_config["timewindows"] = timewindows_td(**time_config) + time_config["time_windows"] = time_windows_td(**time_config) else: - time_config["start_date"] = time_config["timewindows"][0][0] - 
time_config["end_date"] = time_config["timewindows"][-1][-1] + time_config["start_date"] = time_config["time_windows"][0][0] + time_config["end_date"] = time_config["time_windows"][-1][-1] return time_config @@ -242,7 +244,7 @@ def timewindow2str(datetimes: Sequence) -> Union[str, list[str]]: single timewindow or a list of time windows. Args: - datetimes: A sequence (of sequences) of datetimes, representing a list of timewindows + datetimes: A sequence (of sequences) of datetimes, representing a list of time_windows Returns: A sequence of strings for each time window @@ -279,7 +281,7 @@ def str2timewindow( return datetimes -def timewindows_ti( +def time_windows_ti( start_date=None, end_date=None, intervals=None, horizon=None, growth="incremental", **_ ): """ @@ -336,7 +338,7 @@ def timewindows_ti( return [(timelimits[0], i) for i in timelimits[1:]] -def timewindows_td( +def time_windows_td( start_date=None, end_date=None, timeintervals=None, timehorizon=None, timeoffset=None, **_ ): """ diff --git a/floatcsep/utils/readers.py b/floatcsep/utils/readers.py index aa02804..dbe7ea0 100644 --- a/floatcsep/utils/readers.py +++ b/floatcsep/utils/readers.py @@ -1,19 +1,173 @@ import argparse +import csv import logging import os.path import time import xml.etree.ElementTree as eTree +import csep import h5py import numpy import pandas +import pandas as pd +from csep.core.catalogs import CSEPCatalog from csep.core.regions import QuadtreeGrid2D, CartesianGrid2D from csep.models import Polygon +from csep.utils.time_utils import strptime_to_utc_epoch log = logging.getLogger(__name__) +class CatalogForecastParsers: -class ForecastParsers: + @staticmethod + def csv(filename, **kwargs): + csep_headers = ['lon', 'lat', 'magnitude', 'time_string', 'depth', 'catalog_id', + 'event_id'] + hermes_headers = ['realization_id', 'magnitude', 'depth', 'latitude', 'longitude', + 'time'] + headers_df = pd.read_csv(filename, nrows=0).columns.str.strip().to_list() + + # CSEP headers + if 
headers_df[:2] == csep_headers[:2]: + + return csep.load_catalog_forecast(filename, **kwargs) + + elif headers_df == hermes_headers: + return csep.load_catalog_forecast(filename, + catalog_loader=CatalogForecastParsers.load_hermes_catalog, + **kwargs + ) + else: + raise Exception('Catalog Forecast could not be loaded') + + @staticmethod + def load_hermes_catalog(filename, **kwargs): + """ Loads hermes synthetic catalogs in csep-ascii format. + + This function can load multiple catalogs stored in a single file. This typically called to + load a catalog-based forecast, but could also load a collection of catalogs stored in the same file + + Args: + filename (str): filepath or directory of catalog files + **kwargs (dict): passed to class constructor + + Return: + yields CSEPCatalog class + """ + + def read_float(val): + """Returns val as float or None if unable""" + try: + val = float(val) + except: + val = None + return val + + def is_header_line(line): + if line[0].lower() == 'realization_id': + return True + else: + return False + + def read_catalog_line(line): + # convert to correct types + + catalog_id = int(line[0]) + magnitude = read_float(line[1]) + depth = read_float(line[2]) + lat = read_float(line[3]) + lon = read_float(line[4]) + # maybe fractional seconds are not included + origin_time = line[5] + if origin_time: + try: + origin_time = strptime_to_utc_epoch(origin_time, + format='%Y-%m-%d %H:%M:%S.%f') + except ValueError: + origin_time = strptime_to_utc_epoch(origin_time, + format='%Y-%m-%d %H:%M:%S') + + event_id = 0 + # temporary event + temp_event = (event_id, origin_time, lat, lon, depth, magnitude) + return temp_event, catalog_id + + # handle all catalogs in single file + if os.path.isfile(filename): + with open(filename, 'r', newline='') as input_file: + catalog_reader = csv.reader(input_file, delimiter=',') + # csv treats everything as a string convert to correct types + events = [] + # all catalogs should start at zero + prev_id = None + for line 
in catalog_reader: + # skip header line on first read if included in file + if prev_id is None: + if is_header_line(line): + continue + # read line and return catalog id + temp_event, catalog_id = read_catalog_line(line) + empty = False + # OK if event_id is empty + if all([val in (None, '') for val in temp_event[1:]]): + empty = True + # first event is when prev_id is none, catalog_id should always start at zero + if prev_id is None: + prev_id = 0 + # if the first catalog doesn't start at zero + if catalog_id != prev_id: + if not empty: + events = [temp_event] + else: + events = [] + for id in range(catalog_id): + yield CSEPCatalog(data=[], catalog_id=id, **kwargs) + prev_id = catalog_id + continue + # accumulate event if catalog_id is the same as previous event + if catalog_id == prev_id: + if not all([val in (None, '') for val in temp_event]): + events.append(temp_event) + prev_id = catalog_id + # create and yield class if the events are from different catalogs + elif catalog_id == prev_id + 1: + yield CSEPCatalog(data=events, catalog_id=prev_id, **kwargs) + # add event to new event list + if not empty: + events = [temp_event] + else: + events = [] + prev_id = catalog_id + # this implies there are empty catalogs, because they are not listed in the ascii file + elif catalog_id > prev_id + 1: + yield CSEPCatalog(data=events, catalog_id=prev_id, **kwargs) + # if prev_id = 0 and catalog_id = 2, then we skipped one catalog. 
thus, we skip catalog_id - prev_id - 1 catalogs + num_empty_catalogs = catalog_id - prev_id - 1 + # first yield empty catalog classes + for id in range(num_empty_catalogs): + yield CSEPCatalog(data=[], + catalog_id=catalog_id - num_empty_catalogs + id, + **kwargs) + prev_id = catalog_id + # add event to new event list + if not empty: + events = [temp_event] + else: + events = [] + else: + raise ValueError( + "catalog_id should be monotonically increasing and events should be ordered by catalog_id") + # yield final catalog, note: since this is just loading catalogs, it has no idea how many should be there + cat = CSEPCatalog(data=events, catalog_id=prev_id, **kwargs) + yield cat + + elif os.path.isdir(filename): + raise NotImplementedError( + "reading from directory or batched files not implemented yet!") + + + +class GriddedForecastParsers: @staticmethod def dat(filename): @@ -151,7 +305,7 @@ def is_mag(num): sep = " " if "tile" in line: - rates, region, magnitudes = ForecastParsers.quadtree(filename) + rates, region, magnitudes = GriddedForecastParsers.quadtree(filename) return rates, region, magnitudes data = pandas.read_csv( @@ -308,13 +462,13 @@ def serialize(): args = parser.parse_args() if args.format == "quadtree": - ForecastParsers.quadtree(args.filename) + GriddedForecastParsers.quadtree(args.filename) if args.format == "dat": - ForecastParsers.dat(args.filename) + GriddedForecastParsers.dat(args.filename) if args.format == "csep" or args.format == "csv": - ForecastParsers.csv(args.filename) + GriddedForecastParsers.csv(args.filename) if args.format == "xml": - ForecastParsers.xml(args.filename) + GriddedForecastParsers.xml(args.filename) if __name__ == "__main__": diff --git a/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv b/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv index fdddae6..c05baeb 100644 --- a/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv +++ 
b/tests/artifacts/models/td_model/forecasts/mock_2020-01-01_2020-01-02.csv @@ -1,2 +1,2 @@ -lon, lat, M, time_string, depth, catalog_id, event_id +lon,lat,M,time_string,depth,catalog_id,event_id 1.0,1.0,5.0,2020-01-01T01:01:01.0,10.0,1,1 \ No newline at end of file diff --git a/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv b/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv index f89633b..ad21ed4 100644 --- a/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv +++ b/tests/artifacts/models/td_model/forecasts/mock_2020-01-02_2020-01-03.csv @@ -1,2 +1,2 @@ -lon, lat, M, time_string, depth, catalog_id, event_id +lon,lat,M,time_string,depth,catalog_id,event_id 1.0,1.0,5.0,2020-01-02T01:01:01.0,10.0,1,1 \ No newline at end of file diff --git a/tests/qa/test_data.py b/tests/e2e/test_data.py similarity index 100% rename from tests/qa/test_data.py rename to tests/e2e/test_data.py diff --git a/tests/integration/test_model_accessors.py b/tests/integration/test_model_accessors.py index 5c940f9..bf27695 100644 --- a/tests/integration/test_model_accessors.py +++ b/tests/integration/test_model_accessors.py @@ -113,7 +113,7 @@ def init_model(name, model_path, **kwargs): return model @patch.object(EnvironmentManager, "create_environment") - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_from_git(self, mock_build_tree, mock_create_environment): """clones model from git, checks with test artifacts""" mock_build_tree.return_value = None @@ -169,7 +169,7 @@ def init_model(name, model_path, **kwargs): model = TimeIndependentModel(name=name, model_path=model_path, **kwargs) return model - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_zenodo(self, mock_buildtree): """downloads model from zenodo, 
checks with test artifacts""" mock_buildtree.return_value = None @@ -194,11 +194,11 @@ def test_zenodo(self, mock_buildtree): model_b.stage() self.assertEqual( - os.path.basename(model_a.registry.get("path")), - os.path.basename(model_b.registry.get("path")), + os.path.basename(model_a.registry.get_attr("path")), + os.path.basename(model_b.registry.get_attr("path")), ) self.assertEqual(model_a.name, model_b.name) - self.assertTrue(filecmp.cmp(model_a.registry.get("path"), model_b.registry.get("path"))) + self.assertTrue(filecmp.cmp(model_a.registry.get_attr("path"), model_b.registry.get_attr("path"))) def test_zenodo_fail(self): name = "mock_zenodo" diff --git a/tests/integration/test_model_docker.py b/tests/integration/test_model_docker.py index beefcb9..4efe165 100644 --- a/tests/integration/test_model_docker.py +++ b/tests/integration/test_model_docker.py @@ -78,20 +78,20 @@ def _make_model(self, subfolder: str, tag: str): workdir=str(model_dir), ) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_valid_model(self, mock_registry): model = self._make_model("valid", "testdocker_valid") model.stage() model.environment.run_command() # Should succeed with no exceptions - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_invalid_image_build_fails(self, mock_registry): model = self._make_model("invalid_image", "testdocker_invalid_image") with self.assertRaises(RuntimeError) as err: model.environment.create_environment(force=True) self.assertIn("Docker build error", str(err.exception)) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_invalid_entrypoint_fails_to_run(self, mock_registry): model = self._make_model("invalid_entrypoint", 
"testdocker_invalid_entrypoint") model.stage() @@ -99,7 +99,7 @@ def test_invalid_entrypoint_fails_to_run(self, mock_registry): model.environment.run_command() self.assertIn("exited with code", str(err.exception)) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_invalid_permission_fails_to_run(self, mock_registry): model = self._make_model("invalid_permission", "testdocker_invalid_permission") model.stage() @@ -107,7 +107,7 @@ def test_invalid_permission_fails_to_run(self, mock_registry): model.environment.run_command() self.assertIn("exited with code", str(err.exception)) - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_valid_custom_uid_gid(self, mock_registry): # todo: look into it model = self._make_model("valid_custom_uid-gid", "testdocker_uid_gid") diff --git a/tests/integration/test_model_infrastructure.py b/tests/integration/test_model_infrastructure.py index 96adb74..81b1bdb 100644 --- a/tests/integration/test_model_infrastructure.py +++ b/tests/integration/test_model_infrastructure.py @@ -35,11 +35,10 @@ def setUp(self): ) def test_time_independent_model_stage(self): - timewindows = [ + time_windows = [ [datetime(2023, 1, 1), datetime(2023, 1, 2)], ] - self.time_independent_model.stage(timewindows=timewindows) - print("a", self.time_independent_model.registry.as_dict()) + self.time_independent_model.stage(time_windows=time_windows) self.assertIn("2023-01-01_2023-01-02", self.time_independent_model.registry.forecasts) def test_time_independent_model_get_forecast(self): @@ -50,10 +49,10 @@ def test_time_independent_model_get_forecast(self): def test_time_independent_model_get_forecast_real(self): tstring = "2023-01-01_2023-01-02" - timewindows = [ + time_windows = [ [datetime(2023, 1, 1), datetime(2023, 1, 2)], ] - 
self.time_independent_model.stage(timewindows=timewindows) + self.time_independent_model.stage(time_windows=time_windows) forecast = self.time_independent_model.get_forecast(tstring) self.assertIsInstance(forecast, GriddedForecast) self.assertAlmostEqual(forecast.data[0, 0], 0.002739726027357392) # 1 / 365 days @@ -63,12 +62,12 @@ def test_time_independent_model_get_forecast_real(self): def test_time_dependent_model_stage(self, mock_venv, mock_conda): mock_venv.return_value = None mock_conda.return_value = None - timewindows = [ + time_windows = [ [datetime(2020, 1, 1), datetime(2020, 1, 2)], [datetime(2020, 1, 2), datetime(2020, 1, 3)], ] tstrings = ["2020-01-01_2020-01-02", "2020-01-02_2020-01-03"] - self.time_dependent_model.stage(timewindows=timewindows) + self.time_dependent_model.stage(time_windows=time_windows) self.assertIn(tstrings[0], self.time_dependent_model.registry.forecasts) self.assertIn(tstrings[1], self.time_dependent_model.registry.forecasts) @@ -78,11 +77,11 @@ def test_time_dependent_model_stage(self, mock_venv, mock_conda): def test_time_dependent_model_get_forecast(self, mock_venv, mock_conda): mock_venv.return_value = None mock_conda.return_value = None - timewindows = [ + time_windows = [ [datetime(2020, 1, 1), datetime(2020, 1, 2)], [datetime(2020, 1, 2), datetime(2020, 1, 3)], ] - self.time_dependent_model.stage(timewindows) + self.time_dependent_model.stage(time_windows) tstring = "2020-01-01_2020-01-02" forecast = self.time_dependent_model.get_forecast(tstring) self.assertIsNotNone(forecast) @@ -123,7 +122,7 @@ def forecast_(_): name = "mock" fname = os.path.join(self._dir, "model.csv") - with patch("floatcsep.readers.ForecastParsers.csv", forecast_): + with patch("floatcsep.readers.GriddedForecastParsers.csv", forecast_): model = self.init_model(name, fname) model.registry.build_tree([[start, end]]) forecast = model.get_forecast(timestring) diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py index 
cce5cdc..ed7d116 100644 --- a/tests/unit/test_experiment.py +++ b/tests/unit/test_experiment.py @@ -41,7 +41,8 @@ def assertEqualExperiment(self, exp_a, exp_b): self.assertEqual(exp_a.registry.workdir, os.getcwd()) self.assertEqual(exp_a.registry.workdir, exp_b.registry.workdir) self.assertEqual(exp_a.start_date, exp_b.start_date) - self.assertEqual(exp_a.timewindows, exp_b.timewindows) + print(exp_a.time_windows, exp_b.time_windows) + self.assertEqual(exp_a.time_windows, exp_b.time_windows) self.assertEqual(exp_a.exp_class, exp_b.exp_class) self.assertEqual(exp_a.region, exp_b.region) numpy.testing.assert_equal(exp_a.magnitudes, exp_b.magnitudes) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 35bd2a4..125294a 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -2,7 +2,7 @@ from unittest import TestCase from floatcsep.model import TimeIndependentModel -from floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.infrastructure.registries import ModelRegistry from floatcsep.infrastructure.repositories import GriddedForecastRepository from unittest.mock import patch, MagicMock, mock_open from floatcsep.model import TimeDependentModel @@ -27,7 +27,7 @@ def assertEqualModel(model_a, model_b): raise AssertionError("Models are not equal") for i in keys_a: - if isinstance(getattr(model_a, i), ForecastRegistry): + if isinstance(getattr(model_a, i), ModelRegistry): continue if not (getattr(model_a, i) == getattr(model_b, i)): print(getattr(model_a, i), getattr(model_b, i)) @@ -59,7 +59,7 @@ def test_from_filesystem(self): @patch("os.makedirs") @patch("floatcsep.model.TimeIndependentModel.get_source") - @patch("floatcsep.infrastructure.registries.ForecastRegistry.build_tree") + @patch("floatcsep.infrastructure.registries.ModelFileRegistry.build_tree") def test_stage_creates_directory(self, mock_build_tree, mock_get_source, mock_makedirs): """Test stage method creates directory.""" model = self.init_model("mock", 
"mockfile.csv") @@ -159,7 +159,7 @@ class TestTimeDependentModel(TestModel): def setUp(self): # Patches - self.patcher_registry = patch("floatcsep.model.ForecastRegistry") + self.patcher_registry = patch("floatcsep.model.ModelRegistry.factory") self.patcher_repository = patch("floatcsep.model.ForecastRepository.factory") self.patcher_environment = patch("floatcsep.model.EnvironmentFactory.get_env") self.patcher_get_source = patch( @@ -167,14 +167,14 @@ def setUp(self): ) # Patch the get_source method on Model # Start patches - self.mock_registry = self.patcher_registry.start() + self.mock_registry_factory = self.patcher_registry.start() self.mock_repository_factory = self.patcher_repository.start() self.mock_environment = self.patcher_environment.start() self.mock_get_source = self.patcher_get_source.start() # Mock instances self.mock_registry_instance = MagicMock() - self.mock_registry.return_value = self.mock_registry_instance + self.mock_registry_factory.return_value = self.mock_registry_instance self.mock_repository_instance = MagicMock() self.mock_repository_factory.return_value = self.mock_repository_instance @@ -185,7 +185,7 @@ def setUp(self): # Set attributes on the mock objects self.mock_registry_instance.workdir = "/path/to/workdir" self.mock_registry_instance.path = "/path/to/model" - self.mock_registry_instance.get.return_value = ( + self.mock_registry_instance.get_args_key.return_value = ( "/path/to/args_file.txt" # Mocking the return of the registry call ) @@ -204,7 +204,7 @@ def tearDown(self): def test_init(self): # Assertions to check if the components were instantiated correctly - self.mock_registry.assert_called_once_with( + self.mock_registry_factory.assert_called_once_with( workdir=os.getcwd(), path=self.model_path, fmt='csv' ) # Ensure the registry is initialized correctly self.mock_repository_factory.assert_called_once_with( @@ -224,13 +224,13 @@ def test_init(self): def test_stage(self, mk): self.model.force_stage = True # Force staging to 
occur - self.model.stage(timewindows=["2020-01-01_2020-12-31"]) + self.model.stage(time_windows=["2020-01-01_2020-12-31"]) self.mock_get_source.assert_called_once_with( self.model.zenodo_id, self.model.giturl, branch=self.model.repo_hash ) self.mock_registry_instance.build_tree.assert_called_once_with( - timewindows=["2020-01-01_2020-12-31"], + time_windows=["2020-01-01_2020-12-31"], model_class="TimeDependentModel", prefix=self.model.__dict__.get("prefix", self.name), args_file=self.model.__dict__.get("args_file", None), @@ -254,7 +254,7 @@ def test_create_forecast(self, prep_args_mock): self.model.create_forecast(tstring, force=True) self.mock_environment_instance.run_command.assert_called_once_with( - f'{self.func} {self.model.registry.get("args_file")}' + f'{self.func} {self.model.registry.get_args_key()}' ) @patch("builtins.open", new_callable=mock_open) @@ -279,7 +279,7 @@ def test_prepare_args(self, mock_json_dump, mock_json_load, mock_open_file): ] # Call the method - args_file_path = self.model.registry.get("args_file") + args_file_path = self.model.registry.get_args_key() self.model.prepare_args(start_date, end_date, custom_arg="value") mock_open_file.assert_any_call(args_file_path, "r") mock_open_file.assert_any_call(args_file_path, "w") @@ -293,7 +293,7 @@ def test_prepare_args(self, mock_json_dump, mock_json_load, mock_open_file): ) json_file_path = "/path/to/args_file.json" - self.model.registry.get.return_value = json_file_path + self.model.registry.get_args_key.return_value = json_file_path self.model.prepare_args(start_date, end_date, custom_arg="value") mock_open_file.assert_any_call(json_file_path, "r") diff --git a/tests/unit/test_plot_handler.py b/tests/unit/test_plot_handler.py index d1fd852..828b1f3 100644 --- a/tests/unit/test_plot_handler.py +++ b/tests/unit/test_plot_handler.py @@ -15,7 +15,7 @@ def test_plot_results(self, mock_timewindow2str, mock_savefig): plot_handler.plot_results(mock_experiment) - 
mock_timewindow2str.assert_called_once_with(mock_experiment.timewindows) + mock_timewindow2str.assert_called_once_with(mock_experiment.time_windows) mock_test.plot_results.assert_called_once_with( ["2021-01-01", "2021-12-31"], mock_experiment.models, mock_experiment.registry ) @@ -63,7 +63,7 @@ def test_plot_catalogs( mock_parse_plot_config.return_value = {"projection": "Mercator"} mock_parse_projection.return_value = MagicMock() - mock_experiment.registry.get_figure.return_value = "cat.png" + mock_experiment.registry.get_figure_key.return_value = "cat.png" plot_handler.plot_catalogs(mock_experiment) diff --git a/tests/unit/test_readers.py b/tests/unit/test_readers.py index f4935a1..1f552e3 100644 --- a/tests/unit/test_readers.py +++ b/tests/unit/test_readers.py @@ -25,7 +25,7 @@ def tearDownClass(cls) -> None: def test_parse_csv(self): fname = os.path.join(self._dir, "model.csv") numpy.seterr(all="ignore") - rates, region, mags = readers.ForecastParsers.csv(fname) + rates, region, mags = readers.GriddedForecastParsers.csv(fname) rts = numpy.array([[1.0, 0.1], [1.0, 0.1], [1.0, 0.1], [1.0, 0.1]]) orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]]) @@ -39,7 +39,7 @@ def test_parse_csv(self): def test_parse_dat(self): fname = csep.utils.datasets.helmstetter_mainshock_fname - rates, region, mags = readers.ForecastParsers.dat(fname) + rates, region, mags = readers.GriddedForecastParsers.dat(fname) forecast = csep.load_gridded_forecast(fname) self.assertEqual(forecast.region, region) @@ -50,7 +50,7 @@ def test_parse_csv_qtree(self): fname = os.path.join(self._dir, "qtree", "TEAM=N10L11.csv") numpy.seterr(all="ignore") - rates, region, mags = readers.ForecastParsers.csv(fname) + rates, region, mags = readers.GriddedForecastParsers.csv(fname) poly = numpy.array( [[-180.0, 66.51326], [-180.0, 79.171335], [-135.0, 79.171335], [-135.0, 66.51326]] @@ -61,7 +61,7 @@ def test_parse_csv_qtree(self): self.assertEqual(8089, rates.shape[0]) 
numpy.testing.assert_allclose(poly, region.polygons[2].points) - rates2, region2, mags2 = readers.ForecastParsers.quadtree(fname) + rates2, region2, mags2 = readers.GriddedForecastParsers.quadtree(fname) numpy.testing.assert_allclose(rates, rates2) numpy.testing.assert_allclose( [i.points for i in region.polygons], [i.points for i in region2.polygons] @@ -78,7 +78,7 @@ def test_parse_xml(self): ) numpy.seterr(all="ignore") - rates, region, mags = readers.ForecastParsers.xml(fname) + rates, region, mags = readers.GriddedForecastParsers.xml(fname) orgs = numpy.array([12.6, 38.3]) poly = numpy.array([[12.6, 38.3], [12.6, 38.4], [12.7, 38.4], [12.7, 38.3]]) @@ -94,7 +94,7 @@ def test_parse_xml(self): def test_serialize_hdf5(self): numpy.seterr(all="ignore") fname = os.path.join(self._dir, "model.csv") - rates, region, mags = readers.ForecastParsers.csv(fname) + rates, region, mags = readers.GriddedForecastParsers.csv(fname) fname_db = os.path.join(self._dir, "model.hdf5") readers.HDF5Serializer.grid2hdf5(rates, region, mags, hdf5_filename=fname_db) @@ -105,7 +105,7 @@ def test_serialize_hdf5(self): def test_parse_hdf5(self): fname = os.path.join(self._dir, "model_h5.hdf5") - rates, region, mags = readers.ForecastParsers.hdf5(fname) + rates, region, mags = readers.GriddedForecastParsers.hdf5(fname) orgs = numpy.array([[0.0, 0.0], [0.1, 0], [0.0, 0.1], [0.1, 0.1]]) poly_3 = numpy.array([[0.1, 0.1], [0.1, 0.2], [0.2, 0.2], [0.2, 0.1]]) diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index 369081f..587523e 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -1,45 +1,46 @@ +import os import unittest from datetime import datetime from unittest.mock import patch, MagicMock -from floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.infrastructure.registries import ModelFileRegistry, ExperimentFileRegistry -class TestForecastRegistry(unittest.TestCase): +class TestModelFileRegistry(unittest.TestCase): def 
setUp(self): - self.registry_file = ForecastRegistry( + self.registry_for_filebased_model = ModelFileRegistry( workdir="/test/workdir", path="/test/workdir/model.txt" ) - self.registry_folder = ForecastRegistry( + self.registry_for_folderbased_model = ModelFileRegistry( workdir="/test/workdir", path="/test/workdir/model" ) def test_call(self): - self.registry_file._parse_arg = MagicMock(return_value="path") - result = self.registry_file.get("path") + self.registry_for_filebased_model._parse_arg = MagicMock(return_value="path") + result = self.registry_for_filebased_model.get_attr("path") self.assertEqual(result, "/test/workdir/model.txt") @patch("os.path.isdir") def test_dir(self, mock_isdir): mock_isdir.return_value = False - self.assertEqual(self.registry_file.dir, "/test/workdir") + self.assertEqual(self.registry_for_filebased_model.dir, "/test/workdir") mock_isdir.return_value = True - self.assertEqual(self.registry_folder.dir, "/test/workdir/model") + self.assertEqual(self.registry_for_folderbased_model.dir, "/test/workdir/model") def test_fmt(self): - self.registry_file.database = "test.db" - self.assertEqual(self.registry_file.fmt, "db") - self.registry_file.database = None - self.assertEqual(self.registry_file.fmt, "txt") + self.registry_for_filebased_model.database = "test.db" + self.assertEqual(self.registry_for_filebased_model.fmt, "db") + self.registry_for_filebased_model.database = None + self.assertEqual(self.registry_for_filebased_model.fmt, "txt") def test_parse_arg(self): - self.assertEqual(self.registry_file._parse_arg("arg"), "arg") - self.assertRaises(Exception, self.registry_file._parse_arg, 123) + self.assertEqual(self.registry_for_filebased_model._parse_arg("arg"), "arg") + self.assertRaises(Exception, self.registry_for_filebased_model._parse_arg, 123) def test_as_dict(self): self.assertEqual( - self.registry_file.as_dict(), + self.registry_for_filebased_model.as_dict(), { "args_file": None, "database": None, @@ -51,44 +52,120 @@ def 
test_as_dict(self): ) def test_abs(self): - result = self.registry_file.abs("file.txt") + result = self.registry_for_filebased_model.abs("file.txt") self.assertTrue(result.endswith("/test/workdir/file.txt")) - def test_absdir(self): - result = self.registry_file.abs_dir("model.txt") + def test_abs_dir(self): + result = self.registry_for_filebased_model.abs_dir("model.txt") self.assertTrue(result.endswith("/test/workdir")) @patch("floatcsep.infrastructure.registries.exists") - def test_fileexists(self, mock_exists): + def test_file_exists(self, mock_exists): mock_exists.return_value = True - self.registry_file.get = MagicMock(return_value="/test/path/file.txt") - self.assertTrue(self.registry_file.file_exists("file.txt")) + self.registry_for_filebased_model.get_attr = MagicMock(return_value="/test/path/file.txt") + self.assertTrue(self.registry_for_filebased_model.file_exists("file.txt")) @patch("os.makedirs") @patch("os.listdir") def test_build_tree_time_independent(self, mock_listdir, mock_makedirs): - timewindows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] - self.registry_file.build_tree( - timewindows=timewindows, model_class="TimeIndependentModel" + time_windows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] + self.registry_for_filebased_model.build_tree( + time_windows=time_windows, model_class="TimeIndependentModel" ) - self.assertIn("2023-01-01_2023-01-02", self.registry_file.forecasts) - # self.assertIn("2023-01-01_2023-01-02", self.registry_file.inventory) + self.assertIn("2023-01-01_2023-01-02", self.registry_for_filebased_model.forecasts) + # self.assertIn("2023-01-01_2023-01-02", self.registry_for_filebased_model.inventory) @patch("os.makedirs") @patch("os.listdir") def test_build_tree_time_dependent(self, mock_listdir, mock_makedirs): mock_listdir.return_value = ["forecast_1.csv"] - timewindows = [ + time_windows = [ [datetime(2023, 1, 1), datetime(2023, 1, 2)], [datetime(2023, 1, 2), datetime(2023, 1, 3)], ] - self.registry_folder.build_tree( - 
timewindows=timewindows, model_class="TimeDependentModel", prefix="forecast" + self.registry_for_folderbased_model.build_tree( + time_windows=time_windows, model_class="TimeDependentModel", prefix="forecast" ) - self.assertIn("2023-01-01_2023-01-02", self.registry_folder.forecasts) - # self.assertTrue(self.registry_folder.inventory["2023-01-01_2023-01-02"]) - self.assertIn("2023-01-02_2023-01-03", self.registry_folder.forecasts) - # self.assertTrue(self.registry_folder.inventory["2023-01-02_2023-01-03"]) + self.assertIn("2023-01-01_2023-01-02", self.registry_for_folderbased_model.forecasts) + self.assertIn("2023-01-02_2023-01-03", self.registry_for_folderbased_model.forecasts) + + +class TestExperimentFileRegistry(unittest.TestCase): + + def setUp(self): + self.registry = ExperimentFileRegistry(workdir="/test/workdir") + + def test_initialization(self): + self.assertEqual(self.registry.workdir, "/test/workdir") + self.assertEqual(self.registry.run_dir, "results") + self.assertEqual(self.registry.results, {}) + self.assertEqual(self.registry.test_catalogs, {}) + self.assertEqual(self.registry.figures, {}) + self.assertEqual(self.registry.model_registries, {}) + + def test_add_and_get_model_registry(self): + model_mock = MagicMock() + model_mock.name = "TestModel" + model_mock.registry = MagicMock(spec=ModelFileRegistry) + + self.registry.add_model_registry(model_mock) + self.assertIn("TestModel", self.registry.model_registries) + self.assertEqual(self.registry.get_model_registry("TestModel"), model_mock.registry) + + @patch("os.makedirs") + def test_build_tree(self, mock_makedirs): + time_windows = [[datetime(2023, 1, 1), datetime(2023, 1, 2)]] + models = [MagicMock(name="Model1"), MagicMock(name="Model2")] + tests = [MagicMock(name="Test1")] + + self.registry.build_tree(time_windows, models, tests) + + timewindow_str = "2023-01-01_2023-01-02" + self.assertIn(timewindow_str, self.registry.results) + self.assertIn(timewindow_str, self.registry.test_catalogs) + 
self.assertIn(timewindow_str, self.registry.figures) + + def test_get_test_catalog_key(self): + self.registry.test_catalogs = {"2023-01-01_2023-01-02": "some/path/to/catalog.json"} + result = self.registry.get_test_catalog_key("2023-01-01_2023-01-02") + self.assertTrue(result.endswith("results/some/path/to/catalog.json")) + + def test_get_result_key(self): + self.registry.results = { + "2023-01-01_2023-01-02": { + "Test1": { + "Model1": "some/path/to/result.json" + } + } + } + result = self.registry.get_result_key("2023-01-01_2023-01-02", "Test1", "Model1") + self.assertTrue(result.endswith("results/some/path/to/result.json")) + + def test_get_figure_key(self): + self.registry.figures = { + "2023-01-01_2023-01-02": { + "Test1": "some/path/to/figure.png", + "catalog_map": "some/path/to/catalog_map.png", + "catalog_time": "some/path/to/catalog_time.png", + "forecasts": {"Model1": "some/path/to/forecast.png"} + } + } + result = self.registry.get_figure_key("2023-01-01_2023-01-02", "Test1") + self.assertTrue(result.endswith("results/some/path/to/figure.png")) + + @patch("floatcsep.infrastructure.registries.exists") + def test_result_exist(self, mock_exists): + mock_exists.return_value = True + self.registry.results = { + "2023-01-01_2023-01-02": { + "Test1": { + "Model1": "some/path/to/result.json" + } + } + } + result = self.registry.result_exist("2023-01-01_2023-01-02", "Test1", "Model1") + self.assertTrue(result) + mock_exists.assert_called() if __name__ == "__main__": diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py index 709f745..110efff 100644 --- a/tests/unit/test_reporting.py +++ b/tests/unit/test_reporting.py @@ -25,7 +25,7 @@ def test_generate_standard_report(self, mock_markdown_report): # Mock experiment without a custom report function mock_experiment = MagicMock() mock_experiment.postprocess.get.return_value = None - mock_experiment.registry.get_figure.return_value = "figure_path" + 
mock_experiment.registry.get_figure_key.return_value = "figure_path" mock_experiment.magnitudes = [0, 1] # Call the generate_report function reporting.generate_report(mock_experiment) diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py index 1cab55a..335ed66 100644 --- a/tests/unit/test_repositories.py +++ b/tests/unit/test_repositories.py @@ -4,8 +4,8 @@ from csep.core.forecasts import GriddedForecast -from floatcsep.utils.readers import ForecastParsers -from floatcsep.infrastructure.registries import ForecastRegistry +from floatcsep.utils.readers import GriddedForecastParsers +from floatcsep.infrastructure.registries import ModelFileRegistry from floatcsep.infrastructure.repositories import ( CatalogForecastRepository, GriddedForecastRepository, @@ -17,17 +17,19 @@ class TestCatalogForecastRepository(unittest.TestCase): def setUp(self): - self.registry = MagicMock(spec=ForecastRegistry) + self.registry = MagicMock(spec=ModelFileRegistry) #todo: Factory registry self.registry.__call__ = MagicMock(return_value="a_duck") + self.registry.fmt = 'csv' @patch("csep.load_catalog_forecast") def test_initialization(self, mock_load_catalog_forecast): repo = CatalogForecastRepository(self.registry, lazy_load=True) self.assertTrue(repo.lazy_load) - @patch("csep.load_catalog_forecast") + @patch("floatcsep.readers.CatalogForecastParsers.csv") def test_load_forecast(self, mock_load_catalog_forecast): repo = CatalogForecastRepository(self.registry) + mock_load_catalog_forecast.return_value = "forecatto" forecast = repo.load_forecast("2023-01-01_2023-01-02") self.assertEqual(forecast, "forecatto") @@ -36,7 +38,7 @@ def test_load_forecast(self, mock_load_catalog_forecast): forecasts = repo.load_forecast(["2023-01-01_2023-01-01", "2023-01-02_2023-01-03"]) self.assertEqual(forecasts, ["forecatto", "forecatto"]) - @patch("csep.load_catalog_forecast") + @patch("floatcsep.readers.CatalogForecastParsers.csv") def test_load_single_forecast(self, 
mock_load_catalog_forecast): # Test _load_single_forecast repo = CatalogForecastRepository(self.registry) @@ -48,7 +50,7 @@ def test_load_single_forecast(self, mock_load_catalog_forecast): class TestGriddedForecastRepository(unittest.TestCase): def setUp(self): - self.registry = MagicMock(spec=ForecastRegistry) + self.registry = MagicMock(spec=ModelFileRegistry) #todo: Factory registry self.registry.fmt = "hdf5" self.registry.__call__ = MagicMock(return_value="a_duck") @@ -56,7 +58,7 @@ def test_initialization(self): repo = GriddedForecastRepository(self.registry, lazy_load=False) self.assertFalse(repo.lazy_load) - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_load_forecast(self, mock_parser): # Mock parser return values mock_parser.return_value = ("rates", "region", "mags") @@ -77,7 +79,7 @@ def test_load_forecast(self, mock_parser): self.assertEqual(forecasts, ["forecatto", "forecatto"]) self.assertEqual(mock_method.call_count, 2) - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_get_or_load_forecast(self, mock_parser): mock_parser.return_value = ("rates", "region", "mags") repo = GriddedForecastRepository(self.registry, lazy_load=False) @@ -98,7 +100,7 @@ def test_get_or_load_forecast(self, mock_parser): @patch.object(GriddedForecast, "__init__", return_value=None) @patch.object(GriddedForecast, "event_count", new_callable=PropertyMock) @patch.object(GriddedForecast, "scale") - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_load_single_forecast(self, mock_parser, mock_scale, mock_count, mock_init): # Mock parser return values mock_count.return_value = 2 @@ -119,7 +121,7 @@ def test_load_single_forecast(self, mock_parser, mock_scale, mock_count, mock_in end_time=datetime.datetime(2024, 1, 1), ) - @patch.object(ForecastParsers, "hdf5") + @patch.object(GriddedForecastParsers, "hdf5") def test_lazy_load_behavior(self, 
mock_parser): mock_parser.return_value = ("rates", "region", "mags") # Test lazy_load behavior @@ -138,10 +140,10 @@ def test_lazy_load_behavior(self, mock_parser): self.assertEqual(forecast, "forecatto") self.assertNotIn("2023-01-02_2023-01-03", repo.forecasts) - @patch("floatcsep.infrastructure.registries.ForecastRegistry") - def test_equal(self, MockForecastRegistry): + @patch("floatcsep.infrastructure.registries.ModelFileRegistry") + def test_equal(self, MockModelFileRegistry): - self.registry = MockForecastRegistry() + self.registry = MockModelFileRegistry() self.repo1 = CatalogForecastRepository(self.registry) self.repo2 = CatalogForecastRepository(self.registry) @@ -160,9 +162,10 @@ def test_equal(self, MockForecastRegistry): class TestResultsRepository(unittest.TestCase): - @patch("floatcsep.infrastructure.repositories.ExperimentRegistry") - def setUp(self, MockRegistry): - self.mock_registry = MockRegistry() + @patch("floatcsep.infrastructure.repositories.ExperimentRegistry.factory") + def setUp(self, mock_registry): + self.mock_registry = MagicMock() + self.mock_registry.return_value = mock_registry() self.results_repo = ResultsRepository(self.mock_registry) def test_initialization(self): @@ -191,9 +194,10 @@ def test_write_result(self, mock_open, mock_json_dump): class TestCatalogRepository(unittest.TestCase): - @patch("floatcsep.infrastructure.repositories.ExperimentRegistry") - def setUp(self, MockRegistry): - self.mock_registry = MockRegistry() + @patch("floatcsep.infrastructure.repositories.ExperimentRegistry.factory") + def setUp(self, mock_registry): + self.mock_registry = MagicMock() + self.mock_registry.return_value = mock_registry() self.catalog_repo = CatalogRepository(self.mock_registry) def test_initialization(self): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 341339f..e4940a2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -11,11 +11,11 @@ import floatcsep.utils.accessors from 
floatcsep.utils.helpers import ( parse_timedelta_string, - timewindows_ti, + time_windows_ti, read_time_cfg, read_region_cfg, parse_csep_func, - timewindows_td, + time_windows_td, ) root_dir = os.path.dirname(os.path.abspath(__file__)) @@ -58,19 +58,19 @@ def test_parse_time_window(self): dt = "1decade" self.assertRaises(ValueError, parse_timedelta_string, dt) - def test_timewindows_ti(self): + def test_time_windows_ti(self): start = datetime(2014, 1, 1) end = datetime(2022, 1, 1) - self.assertEqual(timewindows_ti(start_date=start, end_date=end), [(start, end)]) + self.assertEqual(time_windows_ti(start_date=start, end_date=end), [(start, end)]) t1 = [ (datetime(2014, 1, 1), datetime(2018, 1, 1)), (datetime(2018, 1, 1), datetime(2022, 1, 1)), ] - self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=2), t1) - self.assertEqual(timewindows_ti(start_date=start, end_date=end, horizon="4-years"), t1) - self.assertEqual(timewindows_ti(start_date=start, intervals=2, horizon="4-years"), t1) + self.assertEqual(time_windows_ti(start_date=start, end_date=end, intervals=2), t1) + self.assertEqual(time_windows_ti(start_date=start, end_date=end, horizon="4-years"), t1) + self.assertEqual(time_windows_ti(start_date=start, intervals=2, horizon="4-years"), t1) t2 = [ (datetime(2014, 1, 1, 0, 0), datetime(2015, 2, 22, 10, 17, 8, 571428)), @@ -93,18 +93,18 @@ def test_timewindows_ti(self): ), (datetime(2020, 11, 9, 13, 42, 51, 428571), datetime(2022, 1, 1, 0, 0)), ] - self.assertEqual(timewindows_ti(start_date=start, end_date=end, intervals=7), t2) + self.assertEqual(time_windows_ti(start_date=start, end_date=end, intervals=7), t2) - def test_timewindows_td(self): + def test_time_windows_td(self): start = datetime(2010, 1, 1) end = datetime(2020, 1, 1) self.assertEqual( - timewindows_td(start_date=start, end_date=end, timeintervals=1), [(start, end)] + time_windows_td(start_date=start, end_date=end, timeintervals=1), [(start, end)] ) self.assertEqual( - 
timewindows_td(start_date=start, end_date=end, timeintervals=5), + time_windows_td(start_date=start, end_date=end, timeintervals=5), [ (datetime(2010, 1, 1, 0, 0), datetime(2012, 1, 1, 9, 36)), (datetime(2012, 1, 1, 9, 36), datetime(2013, 12, 31, 19, 12)), @@ -115,11 +115,11 @@ def test_timewindows_td(self): ) self.assertEqual( - timewindows_td(start_date=start, timeintervals=1, timehorizon="1-years"), + time_windows_td(start_date=start, timeintervals=1, timehorizon="1-years"), [(datetime(2010, 1, 1, 0, 0), datetime(2011, 1, 1, 0, 0))], ) self.assertEqual( - timewindows_td(start_date=start, timeintervals=5, timehorizon="5-days"), + time_windows_td(start_date=start, timeintervals=5, timehorizon="5-days"), [ (datetime(2010, 1, 1, 0, 0), datetime(2010, 1, 6, 0, 0)), (datetime(2010, 1, 6, 0, 0), datetime(2010, 1, 11, 0, 0)), @@ -129,19 +129,19 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="10-years", timeoffset="10-years" ), [(datetime(2010, 1, 1, 0, 0), datetime(2020, 1, 1, 0, 0))], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="12-years", timeoffset="10-years" ), [(datetime(2010, 1, 1, 0, 0), datetime(2022, 1, 1, 0, 0))], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="5-years", timeoffset="5-years" ), [ @@ -150,7 +150,7 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=end, timehorizon="5-years", timeoffset="3-years" ), [ @@ -160,7 +160,7 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, end_date=datetime(2010, 2, 1), timehorizon="14-days", @@ -174,7 +174,7 @@ def test_timewindows_td(self): ], ) self.assertEqual( - timewindows_td( + time_windows_td( start_date=start, timeintervals=3, timehorizon="3-years", diff --git 
a/tutorials/case_h/custom_report.py b/tutorials/case_h/custom_report.py index d705d5d..b8ed4b2 100644 --- a/tutorials/case_h/custom_report.py +++ b/tutorials/case_h/custom_report.py @@ -30,8 +30,8 @@ def main(experiment): report.add_figure( f"Input catalog", [ - experiment.registry.get_figure("main_catalog_map"), - experiment.registry.get_figure("main_catalog_time"), + experiment.registry.get_figure_key("main_catalog_map"), + experiment.registry.get_figure_key("main_catalog_time"), ], level=3, ncols=1, @@ -44,7 +44,7 @@ def main(experiment): # Include results from Experiment test = experiment.tests[0] for model in experiment.models: - fig_path = experiment.registry.get_figure(timestr, f"{test.name}_{model.name}") + fig_path = experiment.registry.get_figure_key(timestr, f"{test.name}_{model.name}") report.add_figure( f"{test.name}: {model.name}", fig_path,