@@ -490,6 +490,11 @@ def train_one_epoch(self, profiler=None):
490490 # we need this for the loss average
491491 accumulated_loss = torch .zeros ((2 ), dtype = torch .float32 , device = self .device , requires_grad = False )
492492
493+ if self .max_grad_norm > 0.0 :
494+ accumulated_grad_norm = torch .zeros ((2 ), dtype = torch .float32 , device = self .device , requires_grad = False )
495+ else :
496+ accumulated_grad_norm = None
497+
493498 train_steps = 0
494499 train_start = time .perf_counter_ns ()
495500 self .model_train .zero_grad (set_to_none = True )
@@ -543,7 +548,9 @@ def train_one_epoch(self, profiler=None):
543548 if do_update :
544549 if self .max_grad_norm > 0.0 :
545550 self .gscaler .unscale_ (self .optimizer )
546- clip_grads (self .model_train , self .max_grad_norm )
551+ grad_norm = clip_grads (self .model_train , self .max_grad_norm )
552+ accumulated_grad_norm [0 ] += grad_norm .detach ()
553+ accumulated_grad_norm [1 ] += 1.0
547554
548555 self .gscaler .step (self .optimizer )
549556 self .gscaler .update ()
@@ -581,6 +588,11 @@ def train_one_epoch(self, profiler=None):
581588 # add train steps to log
582589 logs ["train_steps" ] = train_steps
583590
591+ # log gradient norm
592+ if accumulated_grad_norm is not None :
593+ grad_norm = accumulated_grad_norm [0 ] / accumulated_grad_norm [1 ]
594+ logs ["gradient norm" ] = grad_norm .item ()
595+
584596 # global sync is in order
585597 if dist .is_initialized ():
586598 dist .barrier (device_ids = [self .device .index ])
@@ -725,6 +737,9 @@ def get_pad(nchar):
725737 # validation summary
726738 self .logger .info ("Metrics:" )
727739 self .logger .info (print_prefix + "training loss: {}{}" .format (get_pad (pad_len [0 ]), train_logs ["loss" ]))
740+ if "gradient norm" in train_logs :
741+ plen = max_len - len ("gradient norm" )
742+ self .logger .info (print_prefix + "gradient norm: {}{}" .format (get_pad (plen ), train_logs ["gradient norm" ]))
728743 self .logger .info (print_prefix + "validation loss: {}{}" .format (get_pad (pad_len [1 ]), valid_logs ["base" ]["validation loss" ]))
729744 for idk , key in enumerate (print_list [3 :], start = 3 ):
730745 value = valid_logs ["metrics" ][key ]
0 commit comments