5858)
5959from maxtext .common .gcloud_stub import cloud_diagnostics as _cloud_diag , is_decoupled
6060from maxtext .common .gcloud_stub import vertex_tensorboard_modules
61- from maxtext .common .metric_logger import MetricLogger , record_activation_metrics
61+ from maxtext .common import metric_logger
62+ from maxtext .common .metric_logger import record_activation_metrics
6263from maxtext .trainers .post_train .dpo .dpo_utils import _merge_dpo_state , _split_dpo_state , dpo_loss_fn
6364from maxtext .utils import exceptions
6465from maxtext .utils import gcs_utils
@@ -570,10 +571,10 @@ def train_loop(config, recorder, state=None):
570571
571572 start_step = get_first_step (model , state ) # this is the start_step for training
572573 prof = profiler .Profiler (config , offset_step = start_step )
573- metric_logger = MetricLogger (config = config , learning_rate_schedule = learning_rate_schedule )
574+ metric_logger_instance = metric_logger . MetricLogger (config = config , learning_rate_schedule = learning_rate_schedule )
574575
575576 # Write train config params, num model params, and XLA flags to tensorboard
576- metric_logger .write_setup_info_to_tensorboard (state .params )
577+ metric_logger_instance .write_setup_info_to_tensorboard (state .params )
577578
578579 _job_completed_gracefully = False
579580 try :
@@ -611,7 +612,7 @@ def train_loop(config, recorder, state=None):
611612 assert eval_data_iterator
612613 # Explicitly reset the eval iterator and counters before starting the eval loop
613614 eval_data_iterator .reset ()
614- metric_logger .reset_eval_metrics ()
615+ metric_logger_instance .reset_eval_metrics ()
615616
616617 eval_step_count = 0
617618 # pylint: disable=not-callable
@@ -622,11 +623,11 @@ def train_loop(config, recorder, state=None):
622623 break
623624 with jax .set_mesh (mesh ), nn_partitioning .axis_rules (config .logical_axis_rules ):
624625 eval_metrics = p_eval_step (state , eval_batch , nextrng )
625- metric_logger .record_eval_metrics (step , metrics = eval_metrics )
626+ metric_logger_instance .record_eval_metrics (step , metrics = eval_metrics )
626627 max_logging .log (f"Completed eval step { eval_step_count } " )
627628 eval_step_count += 1
628- metric_logger .record_eval_metrics (step , eval_step_count = eval_step_count )
629- if metric_logger .cumulative_eval_metrics ["scalar" ]["eval/avg_loss" ] <= config .target_eval_loss :
629+ metric_logger_instance .record_eval_metrics (step , eval_step_count = eval_step_count )
630+ if metric_logger_instance .cumulative_eval_metrics ["scalar" ]["eval/avg_loss" ] <= config .target_eval_loss :
630631 prof .deactivate ()
631632 raise exceptions .StopTraining (f"Target loss { config .target_eval_loss = } is achieved." )
632633
@@ -635,7 +636,7 @@ def train_loop(config, recorder, state=None):
635636 if step == start_step :
636637 max_utils .print_mem_stats ("After params initialized" )
637638
638- metric_logger .buffer_and_write_train_metrics (metrics , step , step_time_delta )
639+ metric_logger_instance .buffer_and_write_train_metrics (metrics , step , step_time_delta )
639640
640641 if config .save_checkpoint_on_completion :
641642 state_to_save = state if not config .use_dpo else _split_dpo_state (state )[0 ]
@@ -650,7 +651,7 @@ def train_loop(config, recorder, state=None):
650651 finally :
651652 if _job_completed_gracefully :
652653 record_goodput (recorder , RECORD_JOB_END_TIME )
653- metric_logger .flush_metrics_and_cleanup ()
654+ metric_logger_instance .flush_metrics_and_cleanup ()
654655
655656 return state
656657
0 commit comments