Commit 069dd8b

timmoon10 and pre-commit-ci[bot] authored and committed
[PyTorch] Debug NCCL communication overlapping in linear backward with FP8 data (#1620)
* Overlap input all-gather with dgrad GEMM in FP8 linear layers

  Signed-off-by: Tim Moon <[email protected]>

* Add missing docstring

  Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f546444 commit 069dd8b

File tree

6 files changed: +166, -97 lines changed

transformer_engine/pytorch/distributed.py (+20, -14)
@@ -19,7 +19,7 @@
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
 
-from .utils import safely_set_viewless_tensor_data
+from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data
 from .constants import dist_group_type
 from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
@@ -860,23 +860,29 @@ def _all_gather_fp8(
     process_group: dist_group_type,
     *,
     async_op: bool = False,
-    quantizer: Optional[Float8Quantizer] = None,
+    quantizer: Optional[Quantizer] = None,
     out_shape: Optional[list[int]] = None,
 ) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
     """All-gather FP8 tensor along first dimension."""
     world_size = get_distributed_world_size(process_group)
 
+    # Check that quantizer is valid
+    if quantizer is not None and not isinstance(
+        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+    ):
+        raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")
+
     # Output tensor dims
     if out_shape is None:
         out_shape = list(inp.size())
         out_shape[0] *= world_size
 
-    # Quantize input tensor if needed
+    # Cast input tensor to FP8 if needed
+    # Note: We cannot directly all-gather the transposed FP8 tensor,
+    # so temporarily modify quantizer to avoid creating FP8 transpose.
     if not isinstance(inp, Float8TensorBase):
-        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
-        # we cannot directly gather the transposed fp8 tensor
-        # so we need to disable columnwise usage for the quantizer
-        # and then set it back to the original value after quantizing
+        if quantizer is None:
+            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
         init_rowwise_usage = quantizer.rowwise_usage
         init_columnwise_usage = quantizer.columnwise_usage
         quantizer.set_usage(rowwise=True, columnwise=False)
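
The note above is the crux of the row-wise-only cast: concatenating per-rank shards of a transposed tensor along the first dimension is not the same as transposing the gathered tensor. A minimal, self-contained illustration in plain PyTorch (no FP8; naive_gather is a stand-in for the real all-gather):

import torch

def naive_gather(shards):
    # Stand-in for an all-gather along the first dimension
    return torch.cat(shards, dim=0)

# Two "ranks", each holding a [2, 3] shard
shards = [torch.arange(6).reshape(2, 3) + 10 * rank for rank in range(2)]

wanted = naive_gather(shards).t()               # transpose of the gathered tensor: [3, 4]
broken = naive_gather([s.t() for s in shards])  # gather of the transposed shards: [6, 2]

assert wanted.shape == (3, 4) and broken.shape == (6, 2)
# Hence the quantizer is temporarily restricted to row-wise data here, and the
# FP8 transpose is created locally after the communication finishes.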
@@ -888,7 +894,7 @@ def _all_gather_fp8(
 
     # Construct output tensor
     out: Float8TensorBase
-    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+    if quantizer is not None:
         dtype = torch.float32
         device = "cuda"
         if isinstance(inp, Float8Tensor):
@@ -906,9 +912,8 @@
         out._transpose_invalid = True
     else:
         raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
-    # For delayed scaling, scale_inv is from history, so we can pass it from inp to out
-    # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv,
-    # so we can just pass it from inp to out
+
+    # Assume scaling factors are identical across ranks
     out._scale_inv = inp._scale_inv
 
     # Perform communication
@@ -920,12 +925,13 @@
     )
 
     # Make sure FP8 transpose is populated if needed
-    if out._transpose is not None:
+    needs_transpose = (
+        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
+    )
+    if needs_transpose:
         if handle is not None:
             handle.wait()
             handle = None
-        if not isinstance(out, Float8Tensor):
-            raise RuntimeError("FP8TensorBase does not support FP8 transpose yet")
         out._create_transpose()
 
     return out, handle
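
Taken together, the revised _all_gather_fp8 flow is: quantize with row-wise data only, launch the gather, and materialize the FP8 transpose locally only when column-wise (TN) data is actually needed. The sketch below mirrors that control flow with toy stand-ins (ToyQuantizer, ToyFP8Tensor, a fake two-rank gather); none of these are Transformer Engine classes.

import torch

class ToyFP8Tensor:
    """Toy stand-in for an FP8 tensor with lazily created transpose data."""
    def __init__(self, data):
        self._data = data
        self._transpose = None
    def _create_transpose(self):
        self._transpose = self._data.t().contiguous()

class ToyQuantizer:
    """Toy stand-in for a quantizer with row-wise/column-wise usage flags."""
    def __init__(self):
        self.rowwise_usage = True
        self.columnwise_usage = True
    def set_usage(self, *, rowwise, columnwise):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise
    def __call__(self, tensor):
        out = ToyFP8Tensor(tensor)  # no real FP8 cast in this sketch
        if self.columnwise_usage:
            out._create_transpose()
        return out

class FakeWork:
    """Stand-in for an asynchronous communication handle."""
    def wait(self):
        pass

def fake_all_gather(data):
    # Pretend world_size == 2 and both ranks hold identical shards
    return ToyFP8Tensor(torch.cat([data, data], dim=0)), FakeWork()

def all_gather_fp8_sketch(inp, quantizer, non_tn_fp8_gemm_supported=lambda: False):
    # 1) Quantize with row-wise data only; the transposed layout cannot be
    #    gathered along the first dimension, so no transpose is produced here
    init_rowwise = quantizer.rowwise_usage
    init_columnwise = quantizer.columnwise_usage
    quantizer.set_usage(rowwise=True, columnwise=False)
    inp_fp8 = quantizer(inp)
    quantizer.set_usage(rowwise=init_rowwise, columnwise=init_columnwise)
    # 2) Launch the (fake) asynchronous all-gather
    out, handle = fake_all_gather(inp_fp8._data)
    # 3) Create the transpose locally, but only if TN data is actually needed
    if quantizer.columnwise_usage and not non_tn_fp8_gemm_supported():
        handle.wait()
        handle = None
        out._create_transpose()
    return out, handle

out, _ = all_gather_fp8_sketch(torch.arange(6.0).reshape(2, 3), ToyQuantizer())
assert out._transpose is not None and out._transpose.shape == (3, 4)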

transformer_engine/pytorch/module/layernorm_linear.py (+48, -27)
@@ -55,6 +55,7 @@
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -557,12 +558,27 @@ def backward(
            ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
            dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
 
+        # Configure quantizer for grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
         if ctx.grad_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )
+
+        # Prepare grad output tensor
+        # Note: Cast to expected dtype and perform tensor-parallel communication
         nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
         (
             grad_output,
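
The logic above amounts to a small truth table: request both usages by default, but drop column-wise usage when the grad output is a per-tensor-scaled FP8 tensor whose all-gather goes through Userbuffers, since the FP8 transpose is produced manually afterwards. An illustrative helper, not part of the module:

def choose_grad_output_usage(ub_overlap_ag: bool, is_per_tensor_fp8: bool) -> dict:
    rowwise_usage = True        # dgrad GEMM consumes row-wise data
    columnwise_usage = True     # wgrad GEMM consumes column-wise data
    if ub_overlap_ag and is_per_tensor_fp8:
        # FP8 transpose is created manually after the Userbuffers all-gather
        columnwise_usage = False
    return {"rowwise": rowwise_usage, "columnwise": columnwise_usage}

assert choose_grad_output_usage(True, True) == {"rowwise": True, "columnwise": False}
assert choose_grad_output_usage(True, False) == {"rowwise": True, "columnwise": True}
assert choose_grad_output_usage(False, True) == {"rowwise": True, "columnwise": True}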
@@ -575,15 +591,19 @@ def backward(
         )
         nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
 
-        # Prepare GEMM input
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for LayerNorm out tensor
         ln_out_total = None
         ln_out_total_work = None
         if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.input_quantizer
-                quantizer.set_usage(rowwise=False, columnwise=True)
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
             ln_out_total, ln_out_total_work = gather_along_first_dim(
                 ln_out,
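
This hunk is where the commit title happens: the gather of ln_out is launched asynchronously and is only awaited once the wgrad GEMM needs the full tensor, so the NCCL all-gather overlaps with the dgrad GEMM. The sketch below reproduces that scheduling with plain torch.distributed and dense GEMMs standing in for the FP8/Userbuffers machinery; it assumes a single-process gloo group purely to stay self-contained, and that grad_output is already gathered.

import os
import torch
import torch.distributed as dist

def overlapped_linear_backward(ln_out, grad_output, weight, tp_group=None):
    world_size = dist.get_world_size(tp_group)
    # Launch the all-gather without blocking (async_op=True returns a Work handle)
    shards = [torch.empty_like(ln_out) for _ in range(world_size)]
    work = dist.all_gather(shards, ln_out, group=tp_group, async_op=True)
    # dgrad GEMM does not need the gathered input, so it overlaps with the gather
    dgrad = grad_output @ weight
    # The gathered input is only needed from here on
    work.wait()
    ln_out_total = torch.cat(shards, dim=0)
    wgrad = grad_output.t() @ ln_out_total
    return dgrad, wgrad

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    ln_out = torch.randn(4, 8)       # local shard of the normalized input
    weight = torch.randn(6, 8)       # [out_features, in_features]
    grad_output = torch.randn(4, 6)  # already gathered in this toy setup
    dgrad, wgrad = overlapped_linear_backward(ln_out, grad_output, weight)
    assert dgrad.shape == (4, 8) and wgrad.shape == (6, 8)
    dist.destroy_process_group()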
@@ -652,6 +672,8 @@ def backward(
         # Compute grad weight tensor
         wgrad = None
         if ctx.requires_wgrad:
+
+            # Synchronize tensor-parallel communication for input tensor
             if ctx.ub_bulk_dgrad:
                 ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                 if ctx.fp8:
@@ -665,18 +687,25 @@ def backward(
                         # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                         # have a valid transpose.
                         ln_out_total._create_transpose()
+            if ln_out_total_work is not None:
+                ln_out_total_work.wait()
+                ln_out_total_work = None
 
-            else:
-                if ln_out_total_work is not None:
-                    # Synchronize tensor-parallel communication
-                    ln_out_total_work.wait()
-                    ln_out_total_work = None
-
+            # Make sure GEMM inputs have required data
+            if isinstance(ln_out_total, QuantizedTensor):
+                ln_out_total.update_usage(columnwise_usage=True)
             if isinstance(grad_output, QuantizedTensor):
-                # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
-                # already exists.
-                grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+                grad_output.update_usage(columnwise_usage=True)
+
+            # Figure out whether to use split accumulator
+            use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
 
+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                 rs_out = torch.empty(
                     dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
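
Note the switch from update_usage(rowwise_usage=True, columnwise_usage=True) to update_usage(columnwise_usage=True): the wgrad GEMM only needs the column-wise (transposed) copy, so only that copy is requested. A toy model of this on-demand materialization (LazyQuantizedTensor is illustrative, not the library's QuantizedTensor):

import torch

class LazyQuantizedTensor:
    """Toy tensor that creates its column-wise (transposed) copy on demand."""
    def __init__(self, data: torch.Tensor):
        self._rowwise_data = data
        self._columnwise_data = None

    def update_usage(self, *, rowwise_usage: bool = False, columnwise_usage: bool = False):
        # No-op if the requested copy already exists
        if columnwise_usage and self._columnwise_data is None:
            self._columnwise_data = self._rowwise_data.t().contiguous()

ln_out_total = LazyQuantizedTensor(torch.randn(8, 16))
grad_output = LazyQuantizedTensor(torch.randn(8, 4))
for t in (ln_out_total, grad_output):
    t.update_usage(columnwise_usage=True)  # only the wgrad GEMM needs this layout
assert ln_out_total._columnwise_data.shape == (16, 8)
assert grad_output._columnwise_data.shape == (4, 8)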
@@ -685,14 +714,6 @@
             # wgrad GEMM
             # Note: Fuse with bgrad computation if needed
             nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
-            if ctx.fp8:
-                recipe = ctx.fp8_recipe
-                if hasattr(recipe, "fp8_gemm_wgrad"):
-                    wgrad_gemm_use_split_accumulator = (
-                        recipe.fp8_gemm_wgrad.use_split_accumulator
-                    )
-
             wgrad, grad_bias_, *_, rs_out = general_gemm(
                 ln_out_total,
                 grad_output,
@@ -704,7 +725,7 @@
                 ),
                 bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                 out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=wgrad_gemm_use_split_accumulator,
+                use_split_accumulator=use_split_accumulator,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 ub=ub_obj_wgrad,
                 ub_type=ub_type_wgrad,
@@ -728,7 +749,7 @@
                 # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                 clear_tensor_data(ln_out_total)
 
-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
             ln_out_total_work = None

transformer_engine/pytorch/module/layernorm_mlp.py (+43, -21)
@@ -56,7 +56,11 @@
 from ..constants import dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..tensor.float8_tensor import Float8Tensor
+from ..tensor.float8_tensor import (
+    Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+    Float8Tensor,
+)
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ._common import apply_normalization, _fix_gathered_fp8_transpose
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -642,15 +646,27 @@ def backward(
         ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
         ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
 
-        # Prepare grad output tensor
-        # Note: Cast to expected dtype and perform tensor-parallel communication
+        # Configure quantizer for FC2 grad output tensor
+        # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+        # requires column-wise usage
         if ctx.grad_fc2_output_quantizer is not None:
-            # Reduce duplicated transpose, which is performed in grad_output.update_usage
-            if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False)
-            else:
-                ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True)
+            rowwise_usage = True
+            columnwise_usage = True
+            if ctx.ub_overlap_ag and isinstance(
+                ctx.grad_fc2_output_quantizer,
+                (Float8Quantizer, Float8CurrentScalingQuantizer),
+            ):
+                # If data is in FP8 and communication is handled
+                # with Userbuffers, we compute FP8 transposes
+                # manually
+                columnwise_usage = False
+            ctx.grad_fc2_output_quantizer.set_usage(
+                rowwise=rowwise_usage,
+                columnwise=columnwise_usage,
+            )
 
+        # Prepare FC2 grad output tensor
+        # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
             ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
@@ -662,8 +678,7 @@
             ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer
         )
 
-        # Prepare FC1 GEMM input
-        # Note: Perform tensor-parallel communication if needed
+        # Launch tensor-parallel communication for FC1 GEMM input
         ln_out_total = None
         ln_out_total_work = None
         if (
@@ -675,7 +690,12 @@ def backward(
             quantizer = None
             if ctx.fp8:
                 quantizer = ctx.fc1_input_quantizer
-                quantizer.set_usage(rowwise=False, columnwise=True)
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                    # If data is in FP8, we compute FP8 transposes manually
+                    quantizer.set_usage(rowwise=True, columnwise=False)
+                else:
+                    # wgrad GEMM requires input with column-wise usage
+                    quantizer.set_usage(rowwise=False, columnwise=True)
             ln_out_total, ln_out_total_work = gather_along_first_dim(
                 ln_out,
                 ctx.tp_group,
@@ -868,6 +888,8 @@
         # FC1 WGRAD
         fc1_wgrad = None
        if ctx.fc1_weight_requires_grad:
+
+            # Synchronize tensor-parallel communication for FC1 GEMM input tensor
             if ctx.ub_bulk_dgrad:
                 ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
                 if ctx.fp8:
@@ -879,24 +901,24 @@
                         # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                         # have a valid transpose.
                         ln_out_total._create_transpose()
+            if ln_out_total_work is not None:
+                ln_out_total_work.wait()
+                ln_out_total_work = None
 
-            else:
-                if ln_out_total_work is not None:
-                    # Synchronize tensor-parallel communication
-                    ln_out_total_work.wait()
-                    ln_out_total_work = None
-
-            # Make sure GEMM inputs have expected data
+            # Make sure GEMM inputs have required data
             if isinstance(ln_out_total, QuantizedTensor):
-                ln_out_total.update_usage(rowwise_usage=True, columnwise_usage=True)
+                ln_out_total.update_usage(columnwise_usage=True)
             if isinstance(dact, QuantizedTensor):
-                dact.update_usage(rowwise_usage=True, columnwise_usage=True)
+                dact.update_usage(columnwise_usage=True)
 
+            # Output buffer for overlapping grad input
+            # reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
                 fc1_dgrad_rs_out = torch.empty(
                     fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
                 )
 
+            # wgrad GEMM
             fc1_wgrad_outputs = general_gemm(
                 ln_out_total,
                 dact,
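
The FC1 path mirrors the linear layer: the wait on the outstanding gather handle is hoisted out of the old else branch so that every code path synchronizes before the wgrad GEMM reads ln_out_total or dact. A minimal restatement of that ordering rule, with an illustrative WorkLike protocol rather than the library's types:

from typing import Optional, Protocol

class WorkLike(Protocol):
    def wait(self) -> None: ...

def sync_before_wgrad(gathered_input, work: Optional[WorkLike]):
    """Block on any pending gather before the wgrad GEMM consumes its inputs."""
    if work is not None:
        work.wait()  # now done on every path, not only the non-Userbuffers branch
        work = None
    # ... the wgrad GEMM may now read gathered_input (and dact) safely ...
    return gathered_input, work

class _DummyWork:
    def __init__(self):
        self.completed = False
    def wait(self) -> None:
        self.completed = True

work = _DummyWork()
_, remaining = sync_before_wgrad(object(), work)
assert work.completed and remaining is None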
@@ -930,7 +952,7 @@
             else:
                 fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
 
-        # Synchronize tensor parallel communication
+        # Make sure all tensor-parallel communication is finished
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
             ln_out_total_work = None
