@@ -21,12 +21,8 @@ class MetricLoggerCallback(Callback):
2121 def __init__ (self , save_json : bool = True , ** kwargs ) -> None :
2222 super ().__init__ ()
2323 self .save_json = save_json
24- if self .save_json :
25- self .save_dir = kwargs .get ("save_dir" , None )
26- if self .save_dir is not None :
27- self .save_dir = osp .join (self .save_dir , "Validation" )
28- os .makedirs (self .save_dir , exist_ok = True )
29- self .output_dict = []
24+ self .save_dir = kwargs .get ("save_dir" , None )
25+ self .output_dict = []
3026
3127 def on_validation_end (
3228 self , trainer : pl .Trainer , pl_module : pl .LightningModule
@@ -64,14 +60,57 @@ def on_validation_end(
6460
6561 LOGGER .log (log_dict )
6662
63+ def on_test_end (self , trainer : pl .Trainer , pl_module : pl .LightningModule ) -> None :
64+ """
65+ After finish validation
66+ """
67+ iters = trainer .global_step
68+ metric_dict = pl_module .metric_dict
69+
70+ # Save json
71+ if self .save_json :
72+ item = {}
73+ for metric , score in metric_dict .items ():
74+ if isinstance (score , (int , float )):
75+ item [metric ] = float (f"{ score :.5f} " )
76+ if len (item .keys ()) > 0 :
77+ item ["iters" ] = iters
78+ self .output_dict .append (item )
79+
80+ # Log metric
81+ metric_string = ""
82+ for metric , score in metric_dict .items ():
83+ if isinstance (score , (int , float )):
84+ metric_string += metric + ": " + f"{ score :.5f} " + " | "
85+ metric_string += "\n "
86+
87+ LOGGER .text (metric_string , level = LoggerObserver .INFO )
88+
89+ # Call other loggers
90+ log_dict = [
91+ {"tag" : f"Test/{ k } " , "value" : v , "kwargs" : {"step" : iters }}
92+ for k , v in metric_dict .items ()
93+ ]
94+
95+ LOGGER .log (log_dict )
96+
6797 def teardown (
6898 self , trainer : pl .Trainer , pl_module : pl .LightningModule , stage : str
6999 ) -> None :
70100 """
71101 After finish everything
72102 """
103+
73104 if self .save_json :
74- save_json = osp .join (self .save_dir , "metrics.json" )
75- if len (self .output_dict ) > 0 :
76- with open (save_json , "w" ) as f :
77- json .dump (self .output_dict , f )
105+ if self .save_dir is not None :
106+ save_dir = osp .join (self .save_dir , stage .capitalize ())
107+ os .makedirs (save_dir , exist_ok = True )
108+ save_json = osp .join (save_dir , "metrics.json" )
109+ if len (self .output_dict ) > 0 :
110+ with open (save_json , "w" ) as f :
111+ json .dump (
112+ self .output_dict ,
113+ f ,
114+ indent = 4 ,
115+ default = lambda x : "<not serializable>" ,
116+ )
0 commit comments