Skip to content

Commit 379f194

Browse files
committed
cp
1 parent 08118ec commit 379f194

File tree

2 files changed

+69
-54
lines changed

2 files changed

+69
-54
lines changed

slime/backends/megatron_utils/data.py

Lines changed: 13 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
from megatron.core import mpu
1010
from megatron.core.packed_seq_params import PackedSeqParams
1111

12+
from slime.utils import train_metric_utils
1213
from slime.utils.data import get_minimum_num_micro_batch_size
1314
from slime.utils.flops_utils import calculate_fwd_flops
1415
from slime.utils.metric_utils import compute_pass_rate
1516
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
16-
from slime.utils.timer import Timer
1717
from slime.utils.types import RolloutBatch
1818

1919
from .cp_utils import get_sum_of_sample_mean, slice_with_cp
@@ -419,59 +419,18 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -
419419

420420

421421
def log_perf_data(rollout_id: int, args: Namespace) -> None:
    """Megatron-side wrapper that delegates perf logging to the shared helper.

    Selects the logging rank (TP rank 0, last PP stage, DP rank 0 with
    context parallel) and supplies a per-rank forward-FLOPs estimator;
    all metric assembly and wandb/tensorboard emission happens in
    ``train_metric_utils.log_perf_data_raw``.

    Args:
        rollout_id: Index of the current rollout; used as the logging step.
        args: Parsed CLI arguments (wandb/tensorboard flags, batch sizes).
    """
    train_metric_utils.log_perf_data_raw(
        rollout_id=rollout_id,
        args=args,
        # Only this rank emits logs; other ranks return early in the helper.
        is_primary_rank=(
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.is_pipeline_last_stage()
            and mpu.get_data_parallel_rank(with_context_parallel=True) == 0
        ),
        # Total forward FLOPs split evenly across all ranks, in TFLOPs
        # (1e12 converts FLOPs -> TFLOPs).
        compute_total_fwd_flops=lambda seq_lens: calculate_fwd_flops(seqlens=seq_lens, args=args)
        / dist.get_world_size()
        / 1e12,
    )
475434

476435

477436
def sync_actor_critic_data(

slime/utils/train_metric_utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from argparse import Namespace
2+
from copy import deepcopy
3+
from typing import Callable
4+
5+
import wandb
6+
7+
from slime.utils.timer import Timer
8+
9+
10+
def log_perf_data_raw(
    rollout_id: int, args: Namespace, is_primary_rank: bool, compute_total_fwd_flops: Callable
) -> None:
    """Log timing metrics and derived TFLOPs for compute phases if available.

    Every rank snapshots and resets the shared ``Timer`` (so timer state stays
    consistent across ranks), but only the primary rank assembles and emits
    the metrics to stdout / wandb / tensorboard.

    Args:
        rollout_id: Index of the current rollout; used as the logging step
            (optionally converted to a train step, see below).
        args: Parsed CLI arguments; reads ``use_wandb``, ``use_tensorboard``,
            ``wandb_always_use_train_step`` and batch-size fields.
        is_primary_rank: True on exactly the rank that should emit logs.
        compute_total_fwd_flops: Callable mapping the recorded sequence
            lengths (keyword ``seq_lens``) to per-rank forward TFLOPs, or
            None to skip FLOPs-derived metrics.
    """
    timer_instance = Timer()
    log_dict_raw = deepcopy(timer_instance.log_dict())
    # BUG FIX: snapshot seq_lens BEFORE reset(). The pre-refactor code read
    # timer_instance.seq_lens and only reset the timer at the very end;
    # reading it after reset() would observe cleared state.
    seq_lens = deepcopy(timer_instance.seq_lens)
    timer_instance.reset()

    # Non-primary ranks only needed the reset above.
    if not is_primary_rank:
        return

    log_dict = {f"perf/{key}_time": val for key, val in log_dict_raw.items()}

    if ("perf/actor_train_time" in log_dict) and (compute_total_fwd_flops is not None):
        total_fwd_flops = compute_total_fwd_flops(seq_lens=seq_lens)

        if "perf/log_probs_time" in log_dict:
            log_dict["perf/log_probs_tflops"] = total_fwd_flops / log_dict["perf/log_probs_time"]

        if "perf/ref_log_probs_time" in log_dict:
            log_dict["perf/ref_log_probs_tflops"] = total_fwd_flops / log_dict["perf/ref_log_probs_time"]

        if log_dict["perf/actor_train_time"] > 0:
            # Training ~ fwd + bwd ~ 3x the forward FLOPs.
            log_dict["perf/actor_train_tflops"] = 3 * total_fwd_flops / log_dict["perf/actor_train_time"]
            log_dict["perf/actor_train_tok_per_s"] = sum(seq_lens) / log_dict["perf/actor_train_time"]

    if "perf/train_wait_time" in log_dict and "perf/train_time" in log_dict:
        total_time = log_dict["perf/train_wait_time"] + log_dict["perf/train_time"]
        if total_time > 0:
            log_dict["perf/step_time"] = total_time
            log_dict["perf/wait_time_ratio"] = log_dict["perf/train_wait_time"] / total_time

    print(f"perf {rollout_id}: {log_dict}")

    # Either log against the rollout index, or convert rollouts -> train steps.
    step = (
        rollout_id
        if not args.wandb_always_use_train_step
        else rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
    )
    if args.use_wandb:
        log_dict["rollout/step"] = step
        wandb.log(log_dict)

    if args.use_tensorboard:
        # Imported lazily so tensorboard stays an optional dependency.
        from slime.utils.tensorboard_utils import _TensorboardAdapter

        tb = _TensorboardAdapter(args)
        tb.log(data=log_dict, step=step)

0 commit comments

Comments
 (0)