@@ -450,7 +450,13 @@ def current_it(self):
450450 """Get the current iteration identifier."""
451451 return self .tot_it if self .conf .log_it else self .tot_n_samples
452452
453- def log_train (self , writer : Writer , it : int , train_loss_metrics : LossMetrics ):
453+ def log_train (
454+ self ,
455+ writer : Writer ,
456+ it : int ,
457+ train_loss_metrics : LossMetrics ,
458+ extra_str : str = "" ,
459+ ):
454460 tot_n_samples = self .current_it
455461 all_params = self .all_params
456462 writer .add_scalar ("l2/param_norm" , misc .param_norm (all_params ), tot_n_samples )
@@ -459,8 +465,8 @@ def log_train(self, writer: Writer, it: int, train_loss_metrics: LossMetrics):
459465 str_loss_metrics = [f"{ k } { v :.3E} " for k , v in loss_metrics .items ()]
460466 # Write training losses
461467 logger .info (
462- "[E {} | it {}] loss {{{}}}" .format (
463- self .epoch , it , ", " .join (str_loss_metrics )
468+ "[E {} | it {}] {} loss {{{}}}" .format (
469+ self .epoch , it , extra_str , ", " .join (str_loss_metrics )
464470 )
465471 )
466472 tools .write_dict_summaries (writer , "training" , loss_metrics , tot_n_samples )
@@ -487,7 +493,7 @@ def log_time_and_memory(
487493 self ,
488494 writer : Writer ,
489495 batch_size : int ,
490- ):
496+ ) -> str :
491497 tot_n_samples = self .current_it
492498 steps_per_sec = 0.0
493499 if self .step_timer .num_steps () > 1 :
@@ -528,7 +534,7 @@ def log_time_and_memory(
528534 memory_total = device_stats ["global_total" ]
529535 tools .write_dict_summaries (writer , "memory" , device_stats , tot_n_samples )
530536
531- self . info (
537+ return (
532538 f"[Used { memory_used :.1f} /{ memory_total :.1f} GB | { steps_per_sec :.1f} it/s]"
533539 )
534540
@@ -672,9 +678,13 @@ def train_epoch(
672678
673679 # Log training metrics (loss, ...) and hardware usage
674680 if (it % self .conf .log_every_iter == 0 ) and self .rank == 0 :
675- self .log_train (writer , it , train_loss_metrics )
681+ time_and_mem_str = self .log_time_and_memory (
682+ writer , dataloader .batch_size
683+ )
684+ self .log_train (
685+ writer , it , train_loss_metrics , extra_str = time_and_mem_str
686+ )
676687 train_loss_metrics .clear () # Reset training loss aggregators
677- self .log_time_and_memory (writer , dataloader .batch_size )
678688
679689 # Make plots of training steps
680690 if self .conf .plot_every_iter is not None :
0 commit comments