###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

import torch
from megatron.core.models.common.model_chunk_schedule_plan import (
    TransformerLayerSchedulePlan,
)
from megatron.core.models.common.model_chunk_schedule_plan import (
    TransformerModelChunkSchedulePlan as TransformerModelChunkSchedulePlanBase,
)
from megatron.core.pipeline_parallel.utils import get_comm_stream

from primus.backends.megatron.core.pipeline_parallel.zerobubble.zbpp_utils import (
    WeightGradStore,
)


class TransformerModelChunkSchedulePlan(TransformerModelChunkSchedulePlanBase):

    @staticmethod
    def run(
        f_schedule_plan,
        b_schedule_plan,
        b_grad=None,
        pre_forward=None,
        pre_backward=None,
        post_forward=None,
        post_backward=None,
    ):
        """Model chunk level 1f1b fine-grained scheduler.

        This function schedules the forward and backward passes of a model chunk,
        interleaving the forward and backward functions of multiple Transformer layers
        within the chunk so that submodules of the individual forward and backward
        passes can be overlapped.

        Assuming there are 4 layers in the given model chunk:
        Phase 0: p2p_comm_sync -> forward_preprocess -> p2p_comm_sync -> backward_postprocess
        Phase 1: forward_layer[0] + backward_layer[3], overlapped execution by schedule_layer_1f1b
        Phase 2: forward_layer[1] + backward_layer[2], overlapped execution by schedule_layer_1f1b
        Phase 3: forward_layer[2] + backward_layer[1], overlapped execution by schedule_layer_1f1b
        Phase 4: forward_layer[3] + backward_layer[0], overlapped execution by schedule_layer_1f1b
        Phase 5: send_forward_recv_backward -> send_backward_recv_forward
        Phase 6: backward_dw of the first layer -> forward_postprocess -> backward_preprocess

        Args:
            f_schedule_plan (TransformerModelChunkSchedulePlan): The forward schedule plan
            b_schedule_plan (TransformerModelChunkSchedulePlan): The backward schedule plan
            b_grad (Tensor or None): The gradient of the loss function
            pre_forward (callable or None): The function to call before the forward pass
            pre_backward (callable or None): The function to call before the backward pass
            post_forward (callable or None): The function to call after the forward pass
            post_backward (callable or None): The function to call after the backward pass

        Returns:
            The output of the forward pass.
        """
        f_input = None
        if f_schedule_plan:
            # pp output send/receive sync
            if pre_forward is not None:
                pre_forward(f_schedule_plan.vp_stage)
            f_schedule_plan.record_current_stream()
            f_input = f_schedule_plan.pre_process.forward()

        if b_schedule_plan:
            b_schedule_plan.record_current_stream()
            assert b_grad is not None
            if pre_backward is not None:
                pre_backward(b_schedule_plan.vp_stage)
            b_schedule_plan.record_current_stream()

            if b_schedule_plan.post_process is not None:
                b_grad = b_schedule_plan.post_process.backward(b_grad)

        f_num_layers = f_schedule_plan.num_layers() if f_schedule_plan is not None else 0
        b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
        overlapped_layers = min(f_num_layers, b_num_layers)

        # combined forward and backward pass for overlapped layers
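        # The i-th forward layer is paired with the (b_num_layers - 1 - i)-th backward
        # layer, so the earliest pending forward overlaps with the latest pending backward.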
        for i in range(overlapped_layers):
            f_layer = f_schedule_plan.get_layer(i)
            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
            f_input, b_grad = TransformerLayerSchedulePlan.run(
                f_layer,
                b_layer,
                f_input=f_input,
                b_grad=b_grad,
                is_last_layer_in_bwd=(i == b_num_layers - 1),
            )
            torch.cuda.nvtx.range_pop()

        # backward pass for the remaining layers
        for i in range(overlapped_layers, b_num_layers):
            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
            _, b_grad = TransformerLayerSchedulePlan.run(
                None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
            )
            torch.cuda.nvtx.range_pop()

        # forward pass for the remaining layers
        for i in range(overlapped_layers, f_num_layers):
            f_layer = f_schedule_plan.get_layer(i)
            torch.cuda.nvtx.range_push(f"layer_{i}f")
            f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input)
            torch.cuda.nvtx.range_pop()

        if f_schedule_plan is not None and post_forward is not None:
            # post_forward()/send_forward_recv_forward() runs in the communication stream,
            # so the p2p comm can be overlapped with the attn backward
            with torch.cuda.stream(get_comm_stream()):
                f_schedule_plan.wait_current_stream()
                post_forward(f_input, f_schedule_plan.vp_stage)

        # post_backward()/send_backward_recv_backward() runs in the computation stream,
        # so the p2p comm can be overlapped with the wgrad of the attn backward
        if b_schedule_plan is not None and post_backward is not None:
            b_schedule_plan.wait_current_stream()
            post_backward(b_grad, b_schedule_plan.vp_stage)

        # Delay the dw (weight gradient) work of the backward pass so that it overlaps
        # with the p2p comm above
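        # (WeightGradStore comes from the zero-bubble pipeline utilities: the backward
        # passes above queue their weight-gradient computation instead of running it
        # inline; flush() seals what was queued for this step and pop() executes it
        # here, while the p2p communication issued above is still in flight.)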
        if b_num_layers > 0:
            WeightGradStore.flush()
            WeightGradStore.pop()

        # post process forward
        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
            f_input = f_schedule_plan.post_process.forward(f_input)
        # pre process backward
        if b_schedule_plan is not None:
            b_schedule_plan.pre_process.backward(b_grad)

        if f_schedule_plan:
            f_schedule_plan.wait_current_stream()
        if b_schedule_plan:
            b_schedule_plan.wait_current_stream()

        # Release references as early as possible; this helps avoid memory leaks.
        if b_schedule_plan is not None:
            b_schedule_plan.release_state()

        return f_input
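

# Illustrative usage sketch (not part of the module API). It shows how a pipeline
# schedule might drive `run()` for one combined 1F1B step of a model chunk: the two
# schedule plans and the four p2p callbacks are assumed to be supplied by the
# surrounding schedule, and the parameter names below are hypothetical placeholders.
def _example_1f1b_step(
    f_plan,  # forward-pass TransformerModelChunkSchedulePlan (assumed prebuilt)
    b_plan,  # backward-pass TransformerModelChunkSchedulePlan (assumed prebuilt)
    output_grad,  # gradient flowing into the backward pass, e.g. from the loss
    recv_forward_fn,  # p2p sync/send callbacks provided by the pipeline schedule
    recv_backward_fn,
    send_forward_fn,
    send_backward_fn,
):
    # One 1F1B step: forward of one microbatch overlapped with backward of another.
    return TransformerModelChunkSchedulePlan.run(
        f_plan,
        b_plan,
        b_grad=output_grad,
        pre_forward=recv_forward_fn,
        pre_backward=recv_backward_fn,
        post_forward=send_forward_fn,
        post_backward=send_backward_fn,
    )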