
Commit 2929d58

adding gradient norm tracking when grad clipping is enabled

1 parent 8369d44 · commit 2929d58

File tree

5 files changed: +65 −5 lines

makani/utils/training/autoencoder_trainer.py

Lines changed: 16 additions & 1 deletion
@@ -491,6 +491,11 @@ def train_one_epoch(self, profiler=None):
         # we need this for the loss average
         accumulated_loss = torch.zeros((2), dtype=torch.float32, device=self.device)
 
+        if self.max_grad_norm > 0.0:
+            accumulated_grad_norm = torch.zeros((2), dtype=torch.float32, device=self.device, requires_grad=False)
+        else:
+            accumulated_grad_norm = None
+
         train_steps = 0
         train_start = time.perf_counter_ns()
         self.model_train.zero_grad(set_to_none=True)
@@ -535,7 +540,9 @@ def train_one_epoch(self, profiler=None):
             if do_update:
                 if self.max_grad_norm > 0.0:
                     self.gscaler.unscale_(self.optimizer)
-                    clip_grads(self.model_train, self.max_grad_norm)
+                    grad_norm = clip_grads(self.model_train, self.max_grad_norm)
+                    accumulated_grad_norm[0] += grad_norm.detach()
+                    accumulated_grad_norm[1] += 1.0
 
                 self.gscaler.step(self.optimizer)
                 self.gscaler.update()
@@ -570,6 +577,11 @@ def train_one_epoch(self, profiler=None):
         # add train steps to log
         logs["train_steps"] = train_steps
 
+        # log gradient norm
+        if accumulated_grad_norm is not None:
+            grad_norm = accumulated_grad_norm[0] / accumulated_grad_norm[1]
+            logs["gradient norm"] = grad_norm.item()
+
         # global sync is in order
         if dist.is_initialized():
             dist.barrier(device_ids=[self.device.index])
@@ -719,6 +731,9 @@ def get_pad(nchar):
         # validation summary
         self.logger.info("Metrics:")
         self.logger.info(print_prefix + "training loss: {}{}".format(get_pad(pad_len[0]), train_logs["loss"]))
+        if "gradient norm" in train_logs:
+            plen = max_len - len("gradient norm")
+            self.logger.info(print_prefix + "gradient norm: {}{}".format(get_pad(plen), train_logs["gradient norm"]))
         self.logger.info(print_prefix + "validation loss: {}{}".format(get_pad(pad_len[1]), valid_logs["base"]["validation loss"]))
         for idk, key in enumerate(print_list[3:], start=3):
             value = valid_logs["metrics"][key]
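
This change, repeated with minor variations in the three trainers below, keeps a two-element buffer as the running tally: index 0 accumulates the per-update gradient norm returned by clip_grads, index 1 counts the optimizer updates, and the epoch-level average is logged as logs["gradient norm"]. The standalone sketch below reproduces that bookkeeping outside of Makani; fake_grad_norms is a made-up stand-in for the values a real training loop would produce.

```python
import torch

# Illustrative stand-ins only: in the trainers these norms come from
# clip_grads after real backward passes.
max_grad_norm = 1.0
fake_grad_norms = [torch.tensor(0.8), torch.tensor(2.3), torch.tensor(1.1)]
device = torch.device("cpu")

# mirror of the trainer bookkeeping: [sum of norms, number of updates]
if max_grad_norm > 0.0:
    accumulated_grad_norm = torch.zeros((2), dtype=torch.float32, device=device, requires_grad=False)
else:
    accumulated_grad_norm = None

logs = {}
for grad_norm in fake_grad_norms:
    if max_grad_norm > 0.0:
        accumulated_grad_norm[0] += grad_norm.detach()
        accumulated_grad_norm[1] += 1.0

# epoch end: average norm over all optimizer updates
if accumulated_grad_norm is not None:
    logs["gradient norm"] = (accumulated_grad_norm[0] / accumulated_grad_norm[1]).item()

print(logs)  # {'gradient norm': ~1.4}
```

Note that the counter is only incremented inside the do_update branch, so with gradient accumulation the average is taken over optimizer updates rather than micro-batches.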

makani/utils/training/deterministic_trainer.py

Lines changed: 16 additions & 1 deletion
@@ -490,6 +490,11 @@ def train_one_epoch(self, profiler=None):
         # we need this for the loss average
         accumulated_loss = torch.zeros((2), dtype=torch.float32, device=self.device, requires_grad=False)
 
+        if self.max_grad_norm > 0.0:
+            accumulated_grad_norm = torch.zeros((2), dtype=torch.float32, device=self.device, requires_grad=False)
+        else:
+            accumulated_grad_norm = None
+
         train_steps = 0
         train_start = time.perf_counter_ns()
         self.model_train.zero_grad(set_to_none=True)
@@ -543,7 +548,9 @@ def train_one_epoch(self, profiler=None):
             if do_update:
                 if self.max_grad_norm > 0.0:
                     self.gscaler.unscale_(self.optimizer)
-                    clip_grads(self.model_train, self.max_grad_norm)
+                    grad_norm = clip_grads(self.model_train, self.max_grad_norm)
+                    accumulated_grad_norm[0] += grad_norm.detach()
+                    accumulated_grad_norm[1] += 1.0
 
                 self.gscaler.step(self.optimizer)
                 self.gscaler.update()
@@ -581,6 +588,11 @@ def train_one_epoch(self, profiler=None):
         # add train steps to log
         logs["train_steps"] = train_steps
 
+        # log gradient norm
+        if accumulated_grad_norm is not None:
+            grad_norm = accumulated_grad_norm[0] / accumulated_grad_norm[1]
+            logs["gradient norm"] = grad_norm.item()
+
         # global sync is in order
         if dist.is_initialized():
             dist.barrier(device_ids=[self.device.index])
@@ -725,6 +737,9 @@ def get_pad(nchar):
         # validation summary
         self.logger.info("Metrics:")
         self.logger.info(print_prefix + "training loss: {}{}".format(get_pad(pad_len[0]), train_logs["loss"]))
+        if "gradient norm" in train_logs:
+            plen = max_len - len("gradient norm")
+            self.logger.info(print_prefix + "gradient norm: {}{}".format(get_pad(plen), train_logs["gradient norm"]))
         self.logger.info(print_prefix + "validation loss: {}{}".format(get_pad(pad_len[1]), valid_logs["base"]["validation loss"]))
         for idk, key in enumerate(print_list[3:], start=3):
             value = valid_logs["metrics"][key]
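
Each trainer also prints the averaged norm in its epoch summary, padding the label so the value lines up with the other metrics. The snippet below is a minimal sketch of that formatting; get_pad and max_len are not shown in this diff, so their definitions here are assumptions modeled on how the surrounding code uses them.

```python
# hypothetical reconstruction of the summary-line padding; get_pad and
# max_len are assumed, only the "gradient norm" branch mirrors the diff
train_logs = {"loss": 0.0123, "gradient norm": 0.87}
print_prefix = "    "

metric_names = ["training loss", "gradient norm", "validation loss"]
max_len = max(len(name) for name in metric_names)

def get_pad(nchar):
    # assumed behavior: nchar spaces so the values align in one column
    return " " * nchar

if "gradient norm" in train_logs:
    plen = max_len - len("gradient norm")
    print(print_prefix + "gradient norm: {}{}".format(get_pad(plen), train_logs["gradient norm"]))
```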

makani/utils/training/ensemble_trainer.py

Lines changed: 16 additions & 1 deletion
@@ -503,6 +503,11 @@ def train_one_epoch(self, profiler=None):
         # we need this for the loss average
         accumulated_loss = torch.zeros((2), dtype=torch.float32, device=self.device)
 
+        if self.max_grad_norm > 0.0:
+            accumulated_grad_norm = torch.zeros((2), dtype=torch.float32, device=self.device)
+        else:
+            accumulated_grad_norm = None
+
         train_steps = 0
         train_start = time.perf_counter_ns()
         self.model_train.zero_grad(set_to_none=True)
@@ -552,7 +557,9 @@ def train_one_epoch(self, profiler=None):
             if do_update:
                 if self.max_grad_norm > 0.0:
                     self.gscaler.unscale_(self.model_optimizer)
-                    clip_grads(self.model_train, self.max_grad_norm)
+                    grad_norm = clip_grads(self.model_train, self.max_grad_norm)
+                    accumulated_grad_norm[0] += grad_norm.detach()
+                    accumulated_grad_norm[1] += 1.0
 
                 self.gscaler.step(self.model_optimizer)
                 self.gscaler.update()
@@ -590,6 +597,11 @@ def train_one_epoch(self, profiler=None):
         # add train steps to log
         logs["train_steps"] = train_steps
 
+        # log gradient norm
+        if accumulated_grad_norm is not None:
+            grad_norm = accumulated_grad_norm[0] / accumulated_grad_norm[1]
+            logs["gradient norm"] = grad_norm.item()
+
         # global sync is in order
         if dist.is_initialized():
             dist.barrier(device_ids=[self.device.index])
@@ -774,6 +786,9 @@ def get_pad(nchar):
         # validation summary
         self.logger.info("Metrics:")
         self.logger.info(print_prefix + "training loss: {}{}".format(get_pad(pad_len[0]), train_logs["loss"]))
+        if "gradient norm" in train_logs:
+            plen = max_len - len("gradient norm")
+            self.logger.info(print_prefix + "gradient norm: {}{}".format(get_pad(plen), train_logs["gradient norm"]))
         self.logger.info(print_prefix + "validation loss: {}{}".format(get_pad(pad_len[1]), valid_logs["base"]["validation loss"]))
         for idk, key in enumerate(print_list[3:], start=3):
             value = valid_logs["metrics"][key]

makani/utils/training/stochastic_trainer.py

Lines changed: 16 additions & 1 deletion
@@ -472,6 +472,11 @@ def train_one_epoch(self):
         # we need this for the loss average
         accumulated_loss = torch.zeros((2), dtype=torch.float32, device=self.device)
 
+        if self.max_grad_norm > 0.0:
+            accumulated_grad_norm = torch.zeros((2), dtype=torch.float32, device=self.device, requires_grad=False)
+        else:
+            accumulated_grad_norm = None
+
         train_steps = 0
         train_start = time.perf_counter_ns()
         self.model_train.zero_grad(set_to_none=True)
@@ -518,7 +523,9 @@ def train_one_epoch(self):
             if do_update:
                 if self.max_grad_norm > 0.0:
                     self.gscaler.unscale_(self.optimizer)
-                    clip_grads(self.model_train, self.max_grad_norm)
+                    grad_norm = clip_grads(self.model_train, self.max_grad_norm)
+                    accumulated_grad_norm[0] += grad_norm.detach()
+                    accumulated_grad_norm[1] += 1.0
 
                 # perform weight update
                 self.gscaler.step(self.optimizer)
@@ -556,6 +563,11 @@ def train_one_epoch(self):
         # add train steps to log
         logs["train_steps"] = train_steps
 
+        # log gradient norm
+        if accumulated_grad_norm is not None:
+            grad_norm = accumulated_grad_norm[0] / accumulated_grad_norm[1]
+            logs["gradient norm"] = grad_norm.item()
+
         # global sync is in order
         if dist.is_initialized():
             dist.barrier(device_ids=[self.device.index])
@@ -710,6 +722,9 @@ def get_pad(nchar):
         # validation summary
         self.logger.info("Metrics:")
         self.logger.info(print_prefix + "training loss: {}{}".format(get_pad(pad_len[0]), train_logs["loss"]))
+        if "gradient norm" in train_logs:
+            plen = max_len - len("gradient norm")
+            self.logger.info(print_prefix + "gradient norm: {}{}".format(get_pad(plen), train_logs["gradient norm"]))
         self.logger.info(print_prefix + "validation loss: {}{}".format(get_pad(pad_len[1]), valid_logs["base"]["validation loss"]))
         for idk, key in enumerate(print_list[3:], start=3):
             value = valid_logs["metrics"][key]

makani/utils/training/training_helpers.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def clip_grads(model, max_grad_norm, norm_type=2.0):
 
         param.grad.mul_(clip_factor)
 
-    return
+    return total_gnorm
 
 
 def wandb_register_activations_monitor(model: nn.Module, step: int):
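
The diff only touches the tail of clip_grads, which now returns the total gradient norm instead of None. The body below is a hedged reconstruction of what a compatible implementation might look like, modeled on torch.nn.utils.clip_grad_norm_; only param.grad.mul_(clip_factor) and return total_gnorm come from the actual file, everything else is an assumption.

```python
import torch
import torch.nn as nn


def clip_grads(model: nn.Module, max_grad_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Sketch of a clip_grads that returns the total gradient norm.

    Only the final in-place scaling and the return statement are taken from
    the diff; the rest is an illustrative reconstruction in the style of
    torch.nn.utils.clip_grad_norm_.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor(0.0)

    # total norm_type norm over all parameter gradients
    total_gnorm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)

    # scale the gradients in place when the norm exceeds the threshold
    clip_factor = max_grad_norm / (total_gnorm + 1.0e-6)
    if clip_factor < 1.0:
        for param in model.parameters():
            if param.grad is None:
                continue
            param.grad.mul_(clip_factor)

    return total_gnorm
```

Under this sketch the returned value is the pre-clipping norm, so the averaged "gradient norm" in the trainers reflects how large the raw gradients were before any scaling.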
