Skip to content

Commit 19c76b7

Browse files
committed
Also log memory util and step timings to command line
1 parent c47f9d5 commit 19c76b7

1 file changed

Lines changed: 17 additions & 7 deletions

File tree

gluefactory/trainer.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -450,7 +450,13 @@ def current_it(self):
450450
"""Get the current iteration identifier."""
451451
return self.tot_it if self.conf.log_it else self.tot_n_samples
452452

453-
def log_train(self, writer: Writer, it: int, train_loss_metrics: LossMetrics):
453+
def log_train(
454+
self,
455+
writer: Writer,
456+
it: int,
457+
train_loss_metrics: LossMetrics,
458+
extra_str: str = "",
459+
):
454460
tot_n_samples = self.current_it
455461
all_params = self.all_params
456462
writer.add_scalar("l2/param_norm", misc.param_norm(all_params), tot_n_samples)
@@ -459,8 +465,8 @@ def log_train(self, writer: Writer, it: int, train_loss_metrics: LossMetrics):
459465
str_loss_metrics = [f"{k} {v:.3E}" for k, v in loss_metrics.items()]
460466
# Write training losses
461467
logger.info(
462-
"[E {} | it {}] loss {{{}}}".format(
463-
self.epoch, it, ", ".join(str_loss_metrics)
468+
"[E {} | it {}] {} loss {{{}}}".format(
469+
self.epoch, it, extra_str, ", ".join(str_loss_metrics)
464470
)
465471
)
466472
tools.write_dict_summaries(writer, "training", loss_metrics, tot_n_samples)
@@ -487,7 +493,7 @@ def log_time_and_memory(
487493
self,
488494
writer: Writer,
489495
batch_size: int,
490-
):
496+
) -> str:
491497
tot_n_samples = self.current_it
492498
steps_per_sec = 0.0
493499
if self.step_timer.num_steps() > 1:
@@ -528,7 +534,7 @@ def log_time_and_memory(
528534
memory_total = device_stats["global_total"]
529535
tools.write_dict_summaries(writer, "memory", device_stats, tot_n_samples)
530536

531-
self.info(
537+
return (
532538
f"[Used {memory_used:.1f}/{memory_total:.1f} GB | {steps_per_sec:.1f} it/s]"
533539
)
534540

@@ -672,9 +678,13 @@ def train_epoch(
672678

673679
# Log training metrics (loss, ...) and hardware usage
674680
if (it % self.conf.log_every_iter == 0) and self.rank == 0:
675-
self.log_train(writer, it, train_loss_metrics)
681+
time_and_mem_str = self.log_time_and_memory(
682+
writer, dataloader.batch_size
683+
)
684+
self.log_train(
685+
writer, it, train_loss_metrics, extra_str=time_and_mem_str
686+
)
676687
train_loss_metrics.clear() # Reset training loss aggregators
677-
self.log_time_and_memory(writer, dataloader.batch_size)
678688

679689
# Make plots of training steps
680690
if self.conf.plot_every_iter is not None:

0 commit comments

Comments
 (0)