
Commit f92e40c

pp_warmup optimization (#185)
1 parent a86d271 commit f92e40c

4 files changed: +32 -24 lines changed

primus/README_patch.md

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ These arguments are introduced in the Megatron module logic (e.g., training loop
 | `disable_last_saving` | `false` | v0.1.0 | Skip saving the final checkpoint at the last iteration. | NA | Useful for profiling or benchmarking runs. |
 | `no_fp8_weight_transpose_cache` | `false` | v0.2.0 | Disable the FP8 weight transpose cache to save memory. | `megatron.core.extensions.transformer_engine.TELinear`, `megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear`, `megatron.core.extensions.transformer_engine.TEDelayedScaling` | May affect performance but reduce memory use. |
 | `decoder_pipeline_manual_split_list` | `null` | v0.2.0 | Enable manual pipeline split in (interleaved) 1F1B pipeline parallelism. | `megatron.core.transformer.transformer_block.get_num_layers_to_build`, `megatron.core.transformer.transformer_layer.get_transformer_layer_offset` | May be deprecated when megatron gets updated. |
-| `attn_warmup` | `false` | v0.2.0 | Add attention fwd/bwd warmup to save iter1's time when pp is used. | NA | Can save much time for pipeline debug. |
+| `pp_warmup` | `false` | v0.2.0 | Add attention/mlp fwd/bwd warmup to save iter1's time when pp degree is large. | NA | Can save much time for pipeline debug. |
 | `dump_pp_data` | `false` | v0.2.0 | Enable dumping pp schedule data for visualization. | `megatron.core.pipeline_parallel.schedules.forward_step`, `megatron.core.pipeline_parallel.schedules.backward_step`, `megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving`, `megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving` | Useful for pipeline schedule visualization. |
 | `disable_profiler_activity_cpu` | `false` | v0.2.0 | Disable CPU activity in torch profiling. | NA | If you only want to trace CUDA kernels and get a smaller trace JSON file, you can enable this option. However, if you plan to run with TraceLen, please do not enable it. |
 | `use_rocm_mem_info` | `false` | v0.2.0 | Logging ROCm memory information in Megatron-LM Trainer | NA | If `use_rocm_mem_info = True`, ROCm memory information will be collected with `rocm-smi` at every iteration. |

primus/configs/modules/megatron/primus_megatron_module.yaml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ no_fp8_weight_transpose_cache: false
 decoder_pipeline_manual_split_list: null # int list
 
 # perf
-attn_warmup: false # set to true to decrease iter-1 time when using pp
+pp_warmup: false # set to true to decrease iter-1 time when using pp
 
 # tool
 dump_pp_data: false

primus/modules/trainer/megatron/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -1222,16 +1222,16 @@ def run(self, *args, **kwargs):
         one_logger = get_one_logger()
         args = get_args()
 
-        if args.attn_warmup:
-            from .utils import warmup_attn
+        if args.pp_warmup:
+            from .utils import pp_warmup
 
             log_rank_0(
                 "warmup attn on each rank in parallel to decrease "
                 "the first iter time, especially when pp is used"
             )
             timers = get_timers()
             timers("warmup-attn", log_level=0).start(barrier=True)
-            warmup_attn(args, self.config, self.model, self.optimizer)
+            pp_warmup(args, self.config, self.model, self.optimizer)
             timers("warmup-attn").stop()
             timers.log(["warmup-attn"], barrier=True)
 
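Aside: why a one-off warmup shrinks iteration 1 can be seen outside Megatron entirely. Below is a minimal standalone sketch (plain PyTorch, assuming a CUDA device with bf16 support; nothing here is Primus or Megatron code): the first forward/backward through a fresh module pays one-time costs such as kernel selection/autotuning and CUDA allocator growth, and later passes do not. That first-use cost is what the warmup above moves out of training iteration 1, where it would otherwise stack up across pipeline stages.

    import time

    import torch

    # Toy measurement only: time three fwd+bwd passes through the same layer.
    # The first pass is typically much slower than the rest.
    layer = torch.nn.Linear(4096, 4096).cuda().to(torch.bfloat16)
    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)

    for i in range(3):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        layer(x).sum().backward()
        torch.cuda.synchronize()
        print(f"fwd+bwd pass {i}: {(time.perf_counter() - t0) * 1e3:.2f} ms")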

primus/modules/trainer/megatron/utils.py

Lines changed: 27 additions & 19 deletions
@@ -218,26 +218,34 @@ def get_transformer_layer_offset_patch(config, vp_stage):
 megatron.core.models.gpt.gpt_layer_specs.get_transformer_layer_offset = get_transformer_layer_offset_patch
 
 
-def warmup_attn(args, config, model, optimizer):
-    if model[0].use_forward_hook:
-        model[0].disable_forward_pre_hook()
-
-    attn = model[0].module.module.decoder.layers[0].self_attention
-    warmup_input = torch.randn(args.seq_length, 1, config.hidden_size, device="cuda", dtype=torch.bfloat16)
-    attention_mask = (
-        torch.tril(torch.ones((args.seq_length, args.seq_length), device="cuda")).unsqueeze(0).unsqueeze(0)
-        == 0
-    )
-
-    warmup_output = attn(warmup_input, attention_mask=attention_mask)
-    warmup_output[0].backward(torch.ones_like(warmup_output[0]))
-
+def pp_warmup(args, config, model, optimizer):
     for model_chunk in model:
-        model_chunk.zero_grad_buffer()
-    optimizer.zero_grad()
-
-    if model[0].use_forward_hook:
-        model[0].enable_forward_pre_hook()
+        with model_chunk.no_sync():
+            if model_chunk.use_forward_hook:
+                model_chunk.disable_forward_pre_hook()
+            dtype = torch.float32
+            if config.bf16:
+                dtype = torch.bfloat16
+            elif config.fp16:
+                dtype = torch.float16
+            seq_len = args.seq_length // args.tensor_model_parallel_size // args.context_parallel_size
+
+            for layer in model_chunk.module.module.decoder.layers:
+                attn_input = torch.randn(seq_len, 1, config.hidden_size, device="cuda", dtype=dtype)
+                attention_mask = (
+                    torch.tril(torch.ones((seq_len, seq_len), device="cuda")).unsqueeze(0).unsqueeze(0) == 0
+                )
+                attn_output = layer.self_attention(attn_input, attention_mask=attention_mask)
+                attn_output[0].backward(torch.ones_like(attn_output[0]))
+
+                mlp_input = torch.randn(seq_len, 1, config.hidden_size, device="cuda", dtype=dtype)
+                mlp_output = layer.mlp(mlp_input)
+                mlp_output[0].backward(torch.ones_like(mlp_output[0]))
+
+            if model_chunk.use_forward_hook:
+                model_chunk.enable_forward_pre_hook()
+    optimizer.zero_grad()
+    torch.cuda.empty_cache()
 
 
 def schedule_wrapper(func):
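For readers outside the Megatron code base, the pattern pp_warmup follows can be reduced to a self-contained sketch. Everything below (ToyBlock, toy_pp_warmup, the sizes) is hypothetical illustration code, not part of this commit: run one dummy forward/backward through each attention-like and MLP-like sub-module, then discard the dummy gradients and free the temporary buffers before real training starts.

    import torch
    import torch.nn as nn


    class ToyBlock(nn.Module):
        """Stand-in for a transformer layer: an attention-like and an MLP-like sub-module."""

        def __init__(self, hidden):
            super().__init__()
            self.attn_proj = nn.Linear(hidden, hidden)
            self.mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))


    def toy_pp_warmup(layers, seq_len, hidden, dtype=torch.bfloat16):
        """Dummy fwd/bwd through every sub-module so one-time costs are paid before iteration 1.

        The real pp_warmup above additionally wraps this in no_sync(), toggles forward
        pre-hooks, picks dtype from the training config, and shards seq_len by TP/CP size.
        """
        for layer in layers:
            for sub in (layer.attn_proj, layer.mlp):
                x = torch.randn(seq_len, 1, hidden, device="cuda", dtype=dtype)
                y = sub(x)
                y.backward(torch.ones_like(y))
            layer.zero_grad(set_to_none=True)  # discard the dummy gradients
        torch.cuda.empty_cache()  # release the temporary buffers


    layers = [ToyBlock(1024).cuda().to(torch.bfloat16) for _ in range(2)]
    toy_pp_warmup(layers, seq_len=128, hidden=1024)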
