Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 65 additions & 52 deletions megatron/rl/rl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,21 +493,22 @@ def get_environment_rollouts(
nvtx_range = get_nvtx_range()

if args.rl_offload_optimizer_during_inference:
with nvtx_range("offload-optimizer-state-and-grad-buffers-during-inference"):
with nvtx_range("rl/offload-optimizer-before-inference", time=True):
if not args.rl_training_cuda_graphs:
model[0].offload_grad_buffers()
with nvtx_range("rl/offload/grad-buffers", time=True):
model[0].offload_grad_buffers()
else:
logger.warning(
"Gradient buffers will not be offloaded when training cudagraphs are used!"
)
optimizer.offload_to_cpu()
"Gradient buffers will not be offloaded when training cudagraphs are enabled!")
with nvtx_range("rl/offload/optimizer-state", time=True):
optimizer.offload_to_cpu()

# If we have separate training and inference models we need to refit weights from the training model to the inference model.
has_separate_inference_model = inference_model is not None
if has_separate_inference_model:
# If the separate inference model weights were prefetched to CPU while idle, bring them
# back to GPU before refit/copy and before any CUDA-graph'd inference.
with nvtx_range("prefetch-inference-model-weights-to-gpu"):
with nvtx_range("rl/prefetch-weights-to-gpu", time=True):
inf_core = unwrap_model(inference_model[0])
_maybe_prefetch_separate_inference_model_weights(inf_core, to_cpu=False)
swap_model_weights(model, inference_model, args.refit_method)
Expand All @@ -525,7 +526,7 @@ def get_environment_rollouts(
pg_size = get_pg_size(inference_pg_collection.ep)
assert (n_prompts % pg_size == 0), f"{n_prompts=} must be divisible by {pg_size=}"

with nvtx_range("rollout-collection"):
with nvtx_range("rl/rollout-collection", time=True):
loop = get_asyncio_loop()
with megatron_rl_inference_mode(
inference_model,
Expand All @@ -536,15 +537,15 @@ def get_environment_rollouts(
increment_staleness_on_suspend=True,
) as inference_interface:

with nvtx_range("inference-setup"):
with nvtx_range("rl/inference-setup", time=True):
# Asynchronously run inference and rollout collection
rollout_generator = get_rollout_generator(
args, inference_interface, n_prompts, samples_per_group
)

# NOTE(jbarker): we need to double check this when using PP>1
rank = torch.distributed.get_rank()
with nvtx_range("collect-rollouts"):
with nvtx_range("rl/collect-rollouts", time=True):
if rank == 0:
log_single_rank(
logger,
Expand All @@ -569,16 +570,18 @@ def get_environment_rollouts(
# Just set up space to collect the rollouts
rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]

with nvtx_range("sync-rollouts"):
with nvtx_range("rl/sync-rollouts", time=True):
# Wait for Rollouts to be collected
# TODO(jbarker): double check why this isn't causing rank 0 memory allocations
torch.distributed.broadcast_object_list(rollouts, src=0)
logger.debug(f"Got rollouts on rank {rank}")

if args.rl_offload_optimizer_during_inference:
with nvtx_range("restore-optimizer-state-and-grad-buffers-after-inference"):
model[0].restore_grad_buffers()
optimizer.restore_from_cpu()
with nvtx_range("rl/restore-optimizer-after-inference", time=True):
with nvtx_range("rl/restore/grad-buffers", time=True):
model[0].restore_grad_buffers()
with nvtx_range("rl/restore/optimizer-state", time=True):
optimizer.restore_from_cpu()

if lang_rl_log_dir and rank == get_pg_rank(inference_pg_collection.tp):
with open(
Expand Down Expand Up @@ -679,8 +682,8 @@ def get_logprobs(model, tokens, position_ids, no_grad=False, sequence_packing=Fa

nvtx_range = get_nvtx_range()

with nvtx_range("get-logprobs", time=False):
with nvtx_range("forward-pass", time=False):
with nvtx_range("rl/get-logprobs", time=True):
with nvtx_range("rl/forward-pass", time=True):
# TODO(vitalyk): use fp16/bf16 as a function argument. Do not use args.

attention_mask_for_forward = None
Expand All @@ -707,7 +710,7 @@ def get_logprobs(model, tokens, position_ids, no_grad=False, sequence_packing=Fa
return logits_or_hidden_states
else:
logits = logits_or_hidden_states
with nvtx_range("log-softmax", time=False):
with nvtx_range("rl/log-softmax", time=True):
# We do not need logprobs for the n+1 token.
logprobs = selective_log_softmax(logits[:, :-1, :], tokens[:, 1:])
return logprobs
Expand Down Expand Up @@ -1304,8 +1307,8 @@ def prepare_data_for_update(
model = model[0]
dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)

with nvtx_range("prepare-data-for-update"):
with nvtx_range("compute-group-stats"):
with nvtx_range("rl/prepare-data-for-update", time=True):
with nvtx_range("rl/compute-group-stats", time=True):
group_stats = compute_group_stats(rollouts, tokenizer, args.seq_length)
# TODO(vitalyk): why do we need global_advantages here? go inside packing
advantages = global_advantages = torch.tensor(group_stats.advantages, dtype=dtype).cuda()
Expand Down Expand Up @@ -1345,15 +1348,15 @@ def prepare_data_for_update(
# First we calculate them on a global level and then we split and recalculate on a local level.
# Sequence packing and reporting needs it global but non-packing wants it local.

with nvtx_range("prepare_trajectories"):
with nvtx_range("rl/prepare-trajectories", time=True):
trajs, generation_masks, inference_logprobs = prepare_trajectories(
rollouts, tokenizer, args.seq_length, sequence_packing, args.rl_skip_bos_token
)

packing_context = None
# Build trajectories based on sequence packing or standard processing
if sequence_packing:
with nvtx_range("sequence_packing", time=True):
if args.rl_use_sequence_packing:
with nvtx_range("rl/sequence-packing", time=True):
runtime_state.packing_context = packing_context = pack_all_trajectories(
trajs,
generation_masks,
Expand All @@ -1373,7 +1376,7 @@ def prepare_data_for_update(
logprobs_batch_size = 1
else:
# Always compute standard masks for the original data (we'll need them later)
with nvtx_range("get_ltor_masks_and_position_ids"):
with nvtx_range("rl/get-ltor-masks", time=True):
_, original_loss_mask, original_position_ids = get_ltor_masks_and_position_ids(
trajs,
tokenizer.eod,
Expand All @@ -1392,7 +1395,8 @@ def prepare_data_for_update(
)
logprobs_batch_size = args.micro_batch_size

with torch.no_grad(), nvtx_range("compute_logprobs", time=True):

with torch.no_grad(), nvtx_range("rl/compute-logprobs", time=True):
# Before we can update the model, we need to get the logprobs for the \pi_{old} model.

forward_backward_func = get_forward_backward_func()
Expand All @@ -1408,7 +1412,7 @@ def prepare_data_for_update(
pg_collection = get_attr_wrapped_model(model, "pg_collection")
pp_group = pg_collection.pp

with torch.no_grad(), nvtx_range("compute_old_logprobs", time=True):
with torch.no_grad(), nvtx_range("rl/compute-old-logprobs", time=True):
old_logprobs = compute_logprobs_batch(
model=model,
data_loader=data_loader,
Expand All @@ -1423,7 +1427,7 @@ def prepare_data_for_update(
is_correction=args.rl_inference_logprobs_is_correction,
)

with torch.no_grad(), nvtx_range("compute_ref_logprobs", time=True):
with torch.no_grad(), nvtx_range("rl/compute-ref-logprobs", time=True):
# We need to load the ref model state dict and compute the logprobs for the ref model
cur_st_dict = {
k: (v.cpu() if v is not None else v) for k, v in model.state_dict().items()
Expand All @@ -1446,13 +1450,14 @@ def prepare_data_for_update(
# logprobs are [b, seq, h] now.
model.load_state_dict(cur_st_dict)

torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
with nvtx_range("rl/synchronize-cuda-and-collect-garbage", time=True):
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()


if sequence_packing:
with nvtx_range("pack_logprobs", time=True):
with nvtx_range("rl/pack-logprobs", time=True):
# Store logprobs on gpu in packing context
# Since PackingContext is a dataclass, we add these as new attributes
packing_context.old_logprobs = old_logprobs.cuda()
Expand Down Expand Up @@ -1480,7 +1485,7 @@ def prepare_data_for_update(
packing_context.packed_inference_logprobs = packed_inference_logprobs.cuda()
# Only mark as having inference logprobs for IS correction if enabled
packing_context.has_inference_logprobs = args.rl_inference_logprobs_is_correction
with nvtx_range("create_dataloader"):
with nvtx_range("rl/create-dataloader", time=True):
# @vitalyk: This function also reconfigures the data loader to count the
# global_batch_size in the bins frame of reference.
# I think it will be a better design if we split the data loader creating and logic
Expand All @@ -1497,7 +1502,7 @@ def prepare_data_for_update(
)
loader = get_microbatch_dataloader(len(packing_context.packed_trajs), args.micro_batch_size)
else:
with nvtx_range("align_inference_logprobs", time=True):
with nvtx_range("rl/align-inference-logprobs", time=True):
if inference_logprobs is not None:
inference_logprobs = align_unpacked_inference_logprobs(
inference_logprobs=inference_logprobs,
Expand All @@ -1510,7 +1515,7 @@ def prepare_data_for_update(
# Nullify logprobs if not used in IS correction,
if not args.rl_inference_logprobs_is_correction:
inference_logprobs = None
with nvtx_range("create_dataloader"):
with nvtx_range("rl/create-dataloader", time=True):
# Because of multiturn, our batch sizes for non-sequence packed trajectories are not fixed anymore.
# As in sequence packing above, we need to reconfigure it too.
runtime_state.packing_context = None
Expand Down Expand Up @@ -1539,6 +1544,15 @@ def prepare_data_for_update(
data = TensorDataset(*dataset_tensors)
loader = DataLoader(data, batch_size=args.micro_batch_size)

with nvtx_range("rl/log-wandb-tb", time=True):
maybe_log_training_metrics(
group_stats=group_stats,
current_iteration=args.curr_iteration,
tokenizer=tokenizer,
example_group=example_group,
wandb_writer=wandb_writer,
tb_writer=tb_writer,
)

return RerunDataIterator(itertools.cycle(loader)), group_stats, example_groups

Expand Down Expand Up @@ -1667,9 +1681,10 @@ def evaluate_and_print_results_rl(
'top_k': args.rl_default_top_k,
},
)
evaluation_responses = loop.run_until_complete(agent.run_evaluation(request))
if not isinstance(evaluation_responses, list):
evaluation_responses = [evaluation_responses]
with get_nvtx_range()("rl/run-evaluation", time=True):
evaluation_responses = loop.run_until_complete(agent.run_evaluation(request))
if not isinstance(evaluation_responses, list):
evaluation_responses = [evaluation_responses]
else:
evaluation_responses = None

Expand Down Expand Up @@ -1866,7 +1881,7 @@ def megatron_rl_inference_mode(
# If this is a separate RL inference model with offloading enabled, ensure weights are on GPU
# before any CUDA-graph capture/replay or inference. This is a no-op if already on GPU.
model_core = unwrap_model(model[0])
with nvtx_range("prefetch-inference-model-weights-to-gpu"):
with nvtx_range("rl/prefetch-weights-to-gpu", time=True):
_maybe_prefetch_separate_inference_model_weights(model_core, to_cpu=False)

rotary_module = getattr(lang_module, "rotary_pos_emb", None)
Expand All @@ -1879,17 +1894,16 @@ def megatron_rl_inference_mode(
with torch.no_grad():

if offload_optimizer_during_inference:
with nvtx_range("offload-optimizer-state-and-grad-buffers-before-inference"):
with nvtx_range("rl/offload-optimizer-before-inference", time=True):
if not args.rl_training_cuda_graphs:
# Offload grad buffers from the training model (if separate inference model is used)
# or from the inference model (if they're the same model)
model_for_grad_offload = training_model if training_model is not None else model
model_for_grad_offload[0].offload_grad_buffers()
with nvtx_range("rl/offload/grad-buffers", time=True):
model_for_grad_offload = training_model if training_model is not None else model
model_for_grad_offload[0].offload_grad_buffers()
else:
logger.warning(
"Gradient buffers will not be offloaded when training cudagraphs are used!"
)
optimizer.offload_to_cpu()
"Gradient buffers will not be offloaded when training cudagraphs are enabled!")
with nvtx_range("rl/offload/optimizer-state", time=True):
optimizer.offload_to_cpu()

if cuda_graph_impl != "none" and not args.rl_training_cuda_graphs:
toggle_cuda_graphs(lang_module, cuda_graph_impl)
Expand All @@ -1900,7 +1914,7 @@ def megatron_rl_inference_mode(
logger.debug(f"[{dist.get_rank()}] Entered inference mode")
yield inference_interface

with nvtx_range("suspend-engine"):
with nvtx_range("rl/suspend-engine", time=True):
loop.run_until_complete(inference_interface.suspend())
if increment_staleness_on_suspend:
inference_interface.increment_staleness()
Expand Down Expand Up @@ -1930,12 +1944,12 @@ def megatron_rl_inference_mode(
_maybe_prefetch_separate_inference_model_weights(model_core, to_cpu=True)

if offload_optimizer_during_inference:
with nvtx_range("onload-optimizer-state-and-grad-buffers-after-inference"):
# Restore grad buffers to the training model (if separate inference model is used)
# or to the inference model (if they're the same model)
model_for_grad_offload = training_model if training_model is not None else model
model_for_grad_offload[0].restore_grad_buffers()
optimizer.restore_from_cpu()
with nvtx_range("rl/onload-optimizer-after-inference", time=True):
with nvtx_range("rl/onload/grad-buffers", time=True):
model_for_grad_offload = training_model if training_model is not None else model
model_for_grad_offload[0].restore_grad_buffers()
with nvtx_range("rl/onload/optimizer-state", time=True):
optimizer.restore_from_cpu()

# Set training model back to train mode (not inference model if they're separate)
training_lang_module = unwrap_model(training_model[0]) if training_model is not None else lang_module
Expand Down Expand Up @@ -2007,4 +2021,3 @@ def _pad_nonnull_with_zeros(data: list[Optional[torch.Tensor]], max_len: int) ->
# Create zero tensor for None logprobs
padded_data.append(torch.zeros(max_len))
return torch.stack(padded_data)

24 changes: 12 additions & 12 deletions megatron/rl/sequence_packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@dataclass
class PackingInfo:
"""Information about how sequences are packed into bins.

Attributes:
bin_seq_indices: List where each element contains the global sequence indices in that bin
seq_starts: Dict mapping bin index to list of start positions for each sequence in that bin
Expand All @@ -42,7 +42,7 @@ class PackingInfo:
@dataclass
class PackingContext:
"""Context containing all information needed for sequence packing during training.

Attributes:
bin_size: Maximum size of each bin (in tokens)
packer: 'SequencePacker' instance used for packing
Expand Down Expand Up @@ -93,11 +93,11 @@ def load_packed_data_by_index(bin_idx: int, packing_context: PackingContext, log
old_logprobs = getattr(packing_context, 'old_logprobs', None)
if old_logprobs is not None:
old_logprobs = old_logprobs[idx]

ref_logprobs = getattr(packing_context, 'ref_logprobs', None)
if ref_logprobs is not None:
ref_logprobs = ref_logprobs[idx]

# Slice from position 1 because logprobs predict the next token, so they are
# shifted by 1 relative to the input tokens (logprobs has shape [batch, seq_len-1])
loss_mask = packing_context.packed_loss_mask[idx, 1:]
Expand Down Expand Up @@ -403,7 +403,7 @@ def get_default_packed_seq_params(seq_length: int, max_sequences_per_bin: int, d
means no actual packing boundaries

Args:
seq_length: The sequence length
seq_length: The sequence length
max_sequences_per_bin: Max sequences to pack in a bin.
device: Device to create tensors on.

Expand Down Expand Up @@ -976,19 +976,19 @@ def pack_all_trajectories(trajs, generation_masks, inference_logprobs, global_ad
data_parallel_group = mpu.get_data_parallel_group()
nvtx_range = get_nvtx_range()

with nvtx_range("regather_trajectories", time=True):
with nvtx_range("rl/regather-trajectories", time=True):
def _gather(data):
data = data.cuda()
data_list = [torch.empty_like(data) for _ in range(data_parallel_world_size)]
torch.distributed.all_gather(data_list, data, group=data_parallel_group)
return torch.cat(data_list, dim=0)

trajs = _gather(trajs)
generation_masks = _gather(generation_masks)
trajs = _gather(trajs)
generation_masks = _gather(generation_masks)
if inference_logprobs is not None:
inference_logprobs = _gather(inference_logprobs)

with nvtx_range("pack_sequences", time=True):
with nvtx_range("rl/pack-sequences", time=True):
# Create packer with max sequences per bin limit to prevent extreme imbalance
packer = SequencePacker(
bin_size=bin_size,
Expand Down Expand Up @@ -1068,9 +1068,9 @@ def update_microbatch_calculator(
samples_ratio_per_step: float,
num_bins_this_rank: int,
bin_seq_indices: List[List[int]],
global_batch_size: int,
rampup_batch_size: int,
micro_batch_size: int,
global_batch_size: int,
rampup_batch_size: int,
micro_batch_size: int,
decrease_batch_size_if_needed: bool,
):
"""Return a data loader with seqpacked indices with microbatches in bins frame of reference.
Expand Down
Loading
Loading