Skip to content

Commit 71f1ea9

Browse files
committed
refac: Abstracted filepath management functionality as a Mixin to be used in file-based registries. Created the abstract ExperimentRegistry class and its file-based concrete class. Renamed the registries' interface to add/get_*_key. Abstracted registry logging into functions in the logger module.
tests: added unit tests for experiment file registry
1 parent 796a51e commit 71f1ea9

File tree

12 files changed

+369
-271
lines changed

12 files changed

+369
-271
lines changed

floatcsep/evaluation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def plot_results(
287287
# Regular consistency/comparative test plots (e.g., many models)
288288
try:
289289
for time_str in timewindow:
290-
fig_path = registry.get_figure(time_str, self.name)
290+
fig_path = registry.get_figure_key(time_str, self.name)
291291
results = self.read_results(time_str, models)
292292
ax = func(results, plot_args=fargs, **fkwargs)
293293
if "code" in fargs:
@@ -307,7 +307,7 @@ def plot_results(
307307
registry.figures[time_str][fig_name] = os.path.join(
308308
time_str, "figures", fig_name
309309
)
310-
fig_path = registry.get_figure(time_str, fig_name)
310+
fig_path = registry.get_figure_key(time_str, fig_name)
311311
ax = func(result, plot_args=fargs, **fkwargs, show=False)
312312
if "code" in fargs:
313313
exec(fargs["code"])
@@ -318,7 +318,7 @@ def plot_results(
318318
pyplot.show()
319319

320320
elif self.type in ["sequential", "sequential_comparative", "batch"]:
321-
fig_path = registry.get_figure(timewindow[-1], self.name)
321+
fig_path = registry.get_figure_key(timewindow[-1], self.name)
322322
results = self.read_results(timewindow[-1], models)
323323
ax = func(results, plot_args=fargs, **fkwargs)
324324

floatcsep/experiment.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
parse_nested_dicts,
2626
)
2727
from floatcsep.infrastructure.engine import Task, TaskGraph
28+
from floatcsep.infrastructure.logger import log_models_tree, log_results_tree
2829

2930
log = logging.getLogger("floatLogger")
3031

@@ -118,7 +119,7 @@ def __init__(
118119
os.makedirs(os.path.join(workdir, rundir), exist_ok=True)
119120

120121
self.name = name if name else "floatingExp"
121-
self.registry = ExperimentRegistry(workdir, rundir)
122+
self.registry = ExperimentRegistry.factory(workdir=workdir, run_dir=rundir)
122123
self.results_repo = ResultsRepository(self.registry)
123124
self.catalog_repo = CatalogRepository(self.registry)
124125

@@ -296,7 +297,7 @@ def stage_models(self) -> None:
296297
log.info("Staging models")
297298
for i in self.models:
298299
i.stage(self.time_windows)
299-
self.registry.add_forecast_registry(i)
300+
self.registry.add_model_registry(i)
300301

301302
def set_tests(self, test_config: Union[str, Dict, List]) -> list:
302303
"""
@@ -379,9 +380,9 @@ def set_tasks(self) -> None:
379380
self.registry.build_tree(self.time_windows, self.models, self.tests)
380381

381382
log.debug("Pre-run forecast summary")
382-
self.registry.log_forecast_trees(self.time_windows)
383+
log_models_tree(log, self.registry, self.time_windows)
383384
log.debug("Pre-run result summary")
384-
self.registry.log_results_tree()
385+
log_results_tree(log, self.registry)
385386

386387
log.info("Setting up experiment's tasks")
387388

@@ -540,9 +541,9 @@ def run(self) -> None:
540541
self.task_graph.run()
541542
log.info("Calculation completed")
542543
log.debug("Post-run forecast registry")
543-
self.registry.log_forecast_trees(self.time_windows)
544+
log_models_tree(log, self.registry, self.time_windows)
544545
log.debug("Post-run result summary")
545-
self.registry.log_results_tree()
546+
log_results_tree(log, self.registry)
546547

547548
def read_results(self, test: Evaluation, window: str) -> List:
548549
"""
@@ -559,7 +560,7 @@ def make_repr(self) -> None:
559560
560561
"""
561562
log.info("Creating reproducibility config file")
562-
repr_config = self.registry.get("repr_config")
563+
repr_config = self.registry.get_attr("repr_config")
563564

564565
# Dropping region to results folder if it is a file
565566
region_path = self.region_config.get("path", False)
@@ -801,8 +802,8 @@ def get_filecomp(self):
801802
for tw in win_orig:
802803
results[test.name][tw] = dict.fromkeys(models_orig)
803804
for model in models_orig:
804-
orig_path = self.original.registry.get_result(tw, test, model)
805-
repr_path = self.reproduced.registry.get_result(tw, test, model)
805+
orig_path = self.original.registry.get_result_key(tw, test, model)
806+
repr_path = self.reproduced.registry.get_result_key(tw, test, model)
806807

807808
results[test.name][tw][model] = {
808809
"hash": (self.get_hash(orig_path) == self.get_hash(repr_path)),
@@ -811,8 +812,8 @@ def get_filecomp(self):
811812
else:
812813
results[test.name] = dict.fromkeys(models_orig)
813814
for model in models_orig:
814-
orig_path = self.original.registry.get_result(win_orig[-1], test, model)
815-
repr_path = self.reproduced.registry.get_result(win_orig[-1], test, model)
815+
orig_path = self.original.registry.get_result_key(win_orig[-1], test, model)
816+
repr_path = self.reproduced.registry.get_result_key(win_orig[-1], test, model)
816817
results[test.name][model] = {
817818
"hash": (self.get_hash(orig_path) == self.get_hash(repr_path)),
818819
"byte2byte": filecmp.cmp(orig_path, repr_path),

floatcsep/infrastructure/logger.py

+71
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,74 @@ def set_console_log_level(log_level):
6060
for handler in logger.handlers:
6161
if isinstance(handler, logging.StreamHandler):
6262
handler.setLevel(log_level)
63+
64+
65+
66+
67+
def log_models_tree(log, experiment_registry, time_windows):
    """
    Logs, at DEBUG level, a per-model summary of forecast availability.

    For every model registry held by the experiment registry, reports how many
    forecasts exist, how many are missing, and lists the time windows of the
    missing ones.

    Args:
        log: Logger-like object; only its ``debug`` method is used.
        experiment_registry: Object exposing ``model_registries``, a mapping of
            model name to a registry that provides a ``forecasts`` mapping keyed
            by time window and a ``forecast_exists(timewindow)`` method.
        time_windows: Sized collection of time windows; only its length is logged.
    """
    log.debug("===================")
    log.debug(f" Total Time Windows: {len(time_windows)}")
    for model_name, registry in experiment_registry.model_registries.items():
        log.debug(f" Model: {model_name}")
        exists_group = []
        not_exist_group = []

        # Only the keys (time windows) are needed: existence is delegated to the
        # model registry instead of being derived from the stored path.
        for timewindow in registry.forecasts:
            if registry.forecast_exists(timewindow):
                exists_group.append(timewindow)
            else:
                not_exist_group.append(timewindow)

        log.debug(f" Existing forecasts: {len(exists_group)}")
        log.debug(f" Missing forecasts: {len(not_exist_group)}")
        for timewindow in not_exist_group:
            log.debug(f" Time Window: {timewindow}")
    log.debug("===================")
89+
90+
91+
def log_results_tree(log, experiment_registry):
    """
    Logs, at DEBUG level, a summary of the results dictionary, sorted by test.

    For each test and time window, it logs whether all models have results,
    or if some results are missing, and specifies which models are missing.
    Overall totals (all, existing, missing) are logged at the end.

    Args:
        log: Logger-like object; only its ``debug`` method is used.
        experiment_registry: Object exposing ``results``, a nested mapping of
            time window -> test name -> model name, and a
            ``get_result_key(timewindow, test_name, model_name)`` method whose
            return value is checked with ``os.path.exists``.
    """
    log.debug("===================")

    total_results = results_not_exist_count = 0

    # Get all unique test names and sort them
    all_tests = sorted(
        {test_name for tests in experiment_registry.results.values() for test_name in tests}
    )

    for test_name in all_tests:
        log.debug(f"Test: {test_name}")
        for timewindow, tests in experiment_registry.results.items():
            # Not every time window necessarily holds every test.
            if test_name not in tests:
                continue

            models = tests[test_name]
            # The path is always resolved through the registry; the values
            # stored in ``models`` are not consulted.
            missing_models = [
                model_name
                for model_name in models
                if not os.path.exists(
                    experiment_registry.get_result_key(timewindow, test_name, model_name)
                )
            ]
            total_results += len(models)
            results_not_exist_count += len(missing_models)

            if not missing_models:
                log.debug(f" Time Window: {timewindow} - All models evaluated.")
            else:
                log.debug(
                    f" Time Window: {timewindow} - Missing results for models: "
                    f"{', '.join(missing_models)}"
                )

    # Every counted result either exists or is missing.
    results_exist_count = total_results - results_not_exist_count

    log.debug(f"Total Results: {total_results}")
    log.debug(f"Results that Exist: {results_exist_count}")
    log.debug(f"Results that Do Not Exist: {results_not_exist_count}")
    log.debug("===================")

0 commit comments

Comments
 (0)