diff --git a/tunix/sft/peft_trainer.py b/tunix/sft/peft_trainer.py
index 869be6a0f..2357a0e65 100644
--- a/tunix/sft/peft_trainer.py
+++ b/tunix/sft/peft_trainer.py
@@ -41,7 +41,6 @@
 from tunix.sft import profiler
 from tunix.sft import progress_bar
 from tunix.sft import sharding_utils
-from tunix.sft import system_metrics_calculator
 from tunix.sft import utils

 _ModelInputT = Dict[str, ArrayLike]
@@ -233,7 +232,6 @@ def __init__(
     self._mode: sft_metrics_logger.Mode = sft_metrics_logger.Mode.TRAIN
     self._has_aux = False
     self._pbar = None
-    self._flops_measured: bool = False

     self._train_steps, self._restored_custom_metadata = (
         self.checkpoint_manager.maybe_restore(
@@ -659,24 +657,6 @@ def train(
           train_example, self.config.data_sharding_axis
       )

-      if not self._flops_measured and not skip_jit:
-        self._flops_measured = True
-
-        tflops_per_step = system_metrics_calculator.measure_tflops_per_step(
-            train_step_fn=train_step,
-            model=self.model,
-            optimizer=self.optimizer,
-            train_example=train_example,
-        )
-        if tflops_per_step is not None:
-          self.metrics_logger.log(
-              self.metrics_prefix,
-              "tflops_per_step",
-              tflops_per_step,
-              self._mode,
-              0,
-          )
-
       self._throttler.wait_for_next()
       if self.training_hooks:
         self.training_hooks.on_train_step_start(self)