25
25
parse_nested_dicts ,
26
26
)
27
27
from floatcsep .infrastructure .engine import Task , TaskGraph
28
+ from floatcsep .infrastructure .logger import log_models_tree , log_results_tree
28
29
29
30
log = logging .getLogger ("floatLogger" )
30
31
@@ -118,7 +119,7 @@ def __init__(
118
119
os .makedirs (os .path .join (workdir , rundir ), exist_ok = True )
119
120
120
121
self .name = name if name else "floatingExp"
121
- self .registry = ExperimentRegistry (workdir , rundir )
122
+ self .registry = ExperimentRegistry . factory (workdir = workdir , run_dir = rundir )
122
123
self .results_repo = ResultsRepository (self .registry )
123
124
self .catalog_repo = CatalogRepository (self .registry )
124
125
@@ -296,7 +297,7 @@ def stage_models(self) -> None:
296
297
log .info ("Staging models" )
297
298
for i in self .models :
298
299
i .stage (self .time_windows )
299
- self .registry .add_forecast_registry (i )
300
+ self .registry .add_model_registry (i )
300
301
301
302
def set_tests (self , test_config : Union [str , Dict , List ]) -> list :
302
303
"""
@@ -379,9 +380,9 @@ def set_tasks(self) -> None:
379
380
self .registry .build_tree (self .time_windows , self .models , self .tests )
380
381
381
382
log .debug ("Pre-run forecast summary" )
382
- self .registry . log_forecast_trees ( self .time_windows )
383
+ log_models_tree ( log , self .registry , self .time_windows )
383
384
log .debug ("Pre-run result summary" )
384
- self .registry . log_results_tree ( )
385
+ log_results_tree ( log , self .registry )
385
386
386
387
log .info ("Setting up experiment's tasks" )
387
388
@@ -540,9 +541,9 @@ def run(self) -> None:
540
541
self .task_graph .run ()
541
542
log .info ("Calculation completed" )
542
543
log .debug ("Post-run forecast registry" )
543
- self .registry . log_forecast_trees ( self .time_windows )
544
+ log_models_tree ( log , self .registry , self .time_windows )
544
545
log .debug ("Post-run result summary" )
545
- self .registry . log_results_tree ( )
546
+ log_results_tree ( log , self .registry )
546
547
547
548
def read_results (self , test : Evaluation , window : str ) -> List :
548
549
"""
@@ -559,7 +560,7 @@ def make_repr(self) -> None:
559
560
560
561
"""
561
562
log .info ("Creating reproducibility config file" )
562
- repr_config = self .registry .get ("repr_config" )
563
+ repr_config = self .registry .get_attr ("repr_config" )
563
564
564
565
# Dropping region to results folder if it is a file
565
566
region_path = self .region_config .get ("path" , False )
@@ -801,8 +802,8 @@ def get_filecomp(self):
801
802
for tw in win_orig :
802
803
results [test .name ][tw ] = dict .fromkeys (models_orig )
803
804
for model in models_orig :
804
- orig_path = self .original .registry .get_result (tw , test , model )
805
- repr_path = self .reproduced .registry .get_result (tw , test , model )
805
+ orig_path = self .original .registry .get_result_key (tw , test , model )
806
+ repr_path = self .reproduced .registry .get_result_key (tw , test , model )
806
807
807
808
results [test .name ][tw ][model ] = {
808
809
"hash" : (self .get_hash (orig_path ) == self .get_hash (repr_path )),
@@ -811,8 +812,8 @@ def get_filecomp(self):
811
812
else :
812
813
results [test .name ] = dict .fromkeys (models_orig )
813
814
for model in models_orig :
814
- orig_path = self .original .registry .get_result (win_orig [- 1 ], test , model )
815
- repr_path = self .reproduced .registry .get_result (win_orig [- 1 ], test , model )
815
+ orig_path = self .original .registry .get_result_key (win_orig [- 1 ], test , model )
816
+ repr_path = self .reproduced .registry .get_result_key (win_orig [- 1 ], test , model )
816
817
results [test .name ][model ] = {
817
818
"hash" : (self .get_hash (orig_path ) == self .get_hash (repr_path )),
818
819
"byte2byte" : filecmp .cmp (orig_path , repr_path ),
0 commit comments