|
41 | 41 | from tunix.sft import profiler |
42 | 42 | from tunix.sft import progress_bar |
43 | 43 | from tunix.sft import sharding_utils |
44 | | -from tunix.sft import system_metrics_calculator |
45 | 44 | from tunix.sft import utils |
46 | 45 |
|
47 | 46 | _ModelInputT = Dict[str, ArrayLike] |
@@ -233,7 +232,6 @@ def __init__( |
233 | 232 | self._mode: sft_metrics_logger.Mode = sft_metrics_logger.Mode.TRAIN |
234 | 233 | self._has_aux = False |
235 | 234 | self._pbar = None |
236 | | - self._flops_measured: bool = False |
237 | 235 |
|
238 | 236 | self._train_steps, self._restored_custom_metadata = ( |
239 | 237 | self.checkpoint_manager.maybe_restore( |
@@ -659,24 +657,6 @@ def train( |
659 | 657 | train_example, self.config.data_sharding_axis |
660 | 658 | ) |
661 | 659 |
|
662 | | - if not self._flops_measured and not skip_jit: |
663 | | - self._flops_measured = True |
664 | | - |
665 | | - tflops_per_step = system_metrics_calculator.measure_tflops_per_step( |
666 | | - train_step_fn=train_step, |
667 | | - model=self.model, |
668 | | - optimizer=self.optimizer, |
669 | | - train_example=train_example, |
670 | | - ) |
671 | | - if tflops_per_step is not None: |
672 | | - self.metrics_logger.log( |
673 | | - self.metrics_prefix, |
674 | | - "tflops_per_step", |
675 | | - tflops_per_step, |
676 | | - self._mode, |
677 | | - 0, |
678 | | - ) |
679 | | - |
680 | 660 | self._throttler.wait_for_next() |
681 | 661 | if self.training_hooks: |
682 | 662 | self.training_hooks.on_train_step_start(self) |
|
0 commit comments