Skip to content

Commit 2d1b736

Browse files
tianshub (The tunix Authors)
authored and committed
fix metric logging step
PiperOrigin-RevId: 876360948
1 parent df627a6 commit 2d1b736

File tree

2 files changed

+2
-24
lines changed

2 files changed

+2
-24
lines changed

tunix/rl/experimental/agentic_rl_learner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -511,10 +511,8 @@ def _batch_to_train_example(
511511
"""
512512
# Create a merged training_input where each field from the original input
513513
# is repeated G times to align with the G completions.
514-
num_generations = self.algo_config.num_generations
515-
prompt_index = batch_results[0].pair_index // num_generations
516-
if mode == rl_cluster_lib.Mode.TRAIN and self._full_batch_size:
517-
expected_step = prompt_index // self._full_batch_size
514+
if mode == rl_cluster_lib.Mode.TRAIN:
515+
expected_step = batch_results[0].group_id // self._full_batch_size
518516
else:
519517
expected_step = self.rl_cluster.global_steps
520518

tunix/sft/peft_trainer.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from tunix.sft import profiler
4242
from tunix.sft import progress_bar
4343
from tunix.sft import sharding_utils
44-
from tunix.sft import system_metrics_calculator
4544
from tunix.sft import utils
4645

4746
_ModelInputT = Dict[str, ArrayLike]
@@ -233,7 +232,6 @@ def __init__(
233232
self._mode: sft_metrics_logger.Mode = sft_metrics_logger.Mode.TRAIN
234233
self._has_aux = False
235234
self._pbar = None
236-
self._flops_measured: bool = False
237235

238236
self._train_steps, self._restored_custom_metadata = (
239237
self.checkpoint_manager.maybe_restore(
@@ -659,24 +657,6 @@ def train(
659657
train_example, self.config.data_sharding_axis
660658
)
661659

662-
if not self._flops_measured and not skip_jit:
663-
self._flops_measured = True
664-
665-
tflops_per_step = system_metrics_calculator.measure_tflops_per_step(
666-
train_step_fn=train_step,
667-
model=self.model,
668-
optimizer=self.optimizer,
669-
train_example=train_example,
670-
)
671-
if tflops_per_step is not None:
672-
self.metrics_logger.log(
673-
self.metrics_prefix,
674-
"tflops_per_step",
675-
tflops_per_step,
676-
self._mode,
677-
0,
678-
)
679-
680660
self._throttler.wait_for_next()
681661
if self.training_hooks:
682662
self.training_hooks.on_train_step_start(self)

0 commit comments

Comments (0)