|
9 | 9 | from megatron.core import mpu |
10 | 10 | from megatron.core.packed_seq_params import PackedSeqParams |
11 | 11 |
|
| 12 | +from slime.utils import train_metric_utils |
12 | 13 | from slime.utils.data import get_minimum_num_micro_batch_size |
13 | 14 | from slime.utils.flops_utils import calculate_fwd_flops |
14 | 15 | from slime.utils.metric_utils import compute_pass_rate |
15 | 16 | from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions |
16 | | -from slime.utils.timer import Timer |
17 | 17 | from slime.utils.types import RolloutBatch |
18 | 18 |
|
19 | 19 | from .cp_utils import get_sum_of_sample_mean, slice_with_cp |
@@ -419,59 +419,18 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) - |
419 | 419 |
|
420 | 420 |
|
def log_perf_data(rollout_id: int, args: Namespace) -> None:
    """
    Log timing metrics and derived TFLOPs for the compute phases of a rollout.

    The actual metric collection/emission lives in
    ``train_metric_utils.log_perf_data_raw``; this wrapper only supplies the
    Megatron-parallelism-specific pieces:

    - ``is_primary_rank``: True only on TP rank 0, the last pipeline stage,
      and DP rank 0 (with context parallel), so metrics are emitted exactly
      once across the whole parallel group.
    - ``compute_total_fwd_flops``: maps the recorded sequence lengths to
      total forward TFLOPs per rank (total FLOPs divided by world size,
      scaled by 1e12 to express the result in teraFLOPs).

    Args:
        rollout_id: Index of the rollout whose perf data is being logged
            (used downstream as/derived into the logging step).
        args: Parsed runtime arguments forwarded to the raw logger and to
            the FLOPs calculation.
    """
    train_metric_utils.log_perf_data_raw(
        rollout_id=rollout_id,
        args=args,
        is_primary_rank=(
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.is_pipeline_last_stage()
            and mpu.get_data_parallel_rank(with_context_parallel=True) == 0
        ),
        # NOTE(review): the lambda closes over `args`; presumably
        # log_perf_data_raw calls it with the timer's recorded seq_lens —
        # confirm against slime.utils.train_metric_utils.
        compute_total_fwd_flops=lambda seq_lens: calculate_fwd_flops(seqlens=seq_lens, args=args)
        / dist.get_world_size()
        / 1e12,
    )
475 | 434 |
|
476 | 435 |
|
477 | 436 | def sync_actor_critic_data( |
|
0 commit comments