Commit 714424d

Internal.
PiperOrigin-RevId: 861886966
1 parent: f0eab7b

1 file changed: 10 additions & 9 deletions

src/maxtext/trainers/pre_train/train.py
@@ -58,7 +58,8 @@
 )
 from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled
 from maxtext.common.gcloud_stub import vertex_tensorboard_modules
-from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
+from maxtext.common import metric_logger
+from maxtext.common.metric_logger import record_activation_metrics
 from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
 from maxtext.utils import exceptions
 from maxtext.utils import gcs_utils
@@ -570,10 +571,10 @@ def train_loop(config, recorder, state=None):
 
   start_step = get_first_step(model, state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+  metric_logger_instance = metric_logger.MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
 
   # Write train config params, num model params, and XLA flags to tensorboard
-  metric_logger.write_setup_info_to_tensorboard(state.params)
+  metric_logger_instance.write_setup_info_to_tensorboard(state.params)
 
   _job_completed_gracefully = False
   try:
@@ -611,7 +612,7 @@ def train_loop(config, recorder, state=None):
         assert eval_data_iterator
         # Explicitly reset the eval iterator and counters before starting the eval loop
         eval_data_iterator.reset()
-        metric_logger.reset_eval_metrics()
+        metric_logger_instance.reset_eval_metrics()
 
         eval_step_count = 0
         # pylint: disable=not-callable
@@ -622,11 +623,11 @@ def train_loop(config, recorder, state=None):
            break
          with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
            eval_metrics = p_eval_step(state, eval_batch, nextrng)
-         metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+         metric_logger_instance.record_eval_metrics(step, metrics=eval_metrics)
          max_logging.log(f"Completed eval step {eval_step_count}")
          eval_step_count += 1
-       metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
-       if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
+       metric_logger_instance.record_eval_metrics(step, eval_step_count=eval_step_count)
+       if metric_logger_instance.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
          prof.deactivate()
          raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
 
@@ -635,7 +636,7 @@ def train_loop(config, recorder, state=None):
      if step == start_step:
        max_utils.print_mem_stats("After params initialized")
 
-     metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
+     metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta)
 
    if config.save_checkpoint_on_completion:
      state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
@@ -650,7 +651,7 @@ def train_loop(config, recorder, state=None):
  finally:
    if _job_completed_gracefully:
      record_goodput(recorder, RECORD_JOB_END_TIME)
-   metric_logger.flush_metrics_and_cleanup()
+   metric_logger_instance.flush_metrics_and_cleanup()
 
  return state
 
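The commit message gives no rationale, but the diff reads as a rename forced by the new import style: once `from maxtext.common import metric_logger` binds the module name, a local variable called `metric_logger` inside `train_loop` would shadow it, and the constructor call `metric_logger.MetricLogger(...)` could no longer resolve. A minimal, self-contained sketch of that pitfall follows; it uses toy stand-ins (a `SimpleNamespace` and a dummy `MetricLogger`), not the real MaxText classes.

# Sketch of the name-shadowing problem the rename avoids.
# `metric_logger` below is a stand-in namespace, not the real MaxText module.
from types import SimpleNamespace


class MetricLogger:
  """Toy logger standing in for maxtext.common.metric_logger.MetricLogger."""

  def __init__(self, config=None, learning_rate_schedule=None):
    self.config = config
    self.learning_rate_schedule = learning_rate_schedule


# Pretend this binding came from `from maxtext.common import metric_logger`.
metric_logger = SimpleNamespace(MetricLogger=MetricLogger)


def train_loop_broken():
  # BUG: assigning to `metric_logger` makes it a local name for the whole
  # function body, so the read on the right-hand side raises UnboundLocalError
  # before the module-level binding is ever consulted.
  metric_logger = metric_logger.MetricLogger(config={})
  return metric_logger


def train_loop_fixed():
  # The commit's rename keeps the module name free for attribute access,
  # mirroring `metric_logger_instance = metric_logger.MetricLogger(...)` above.
  metric_logger_instance = metric_logger.MetricLogger(config={})
  return metric_logger_instance


if __name__ == "__main__":
  try:
    train_loop_broken()
  except UnboundLocalError as err:
    print(f"broken variant fails as expected: {err}")
  print(f"fixed variant returns: {train_loop_fixed()!r}")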