diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py index 63cbb99b4..6f0e66edf 100644 --- a/axlearn/common/causal_lm.py +++ b/axlearn/common/causal_lm.py @@ -187,12 +187,18 @@ def forward( ctx = ctx.parent module_outputs = ctx.get_module_outputs() + for k, v in flatten_items(module_outputs): + if re.fullmatch(regex, k): + logging.info("aux loss found at %s", k) + else: + logging.info("aux loss not found at %s", k) accumulation = list( v.mean() for k, v in flatten_items(module_outputs) if re.fullmatch(regex, k) ) if accumulation: aux_loss = sum(accumulation) / len(accumulation) else: + logging.warning("aux loss not found: %s", cfg.aux_loss_regex) aux_loss = 0.0 self.add_summary("aux_loss", WeightedScalar(aux_loss, num_targets))