@@ -55,6 +55,7 @@
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
@@ -557,12 +558,27 @@ def backward(
                 ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
                 dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
 
+            # Configure quantizer for grad output tensor
+            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
+            # requires column-wise usage
             if ctx.grad_output_quantizer is not None:
-                # Reduce duplicated transpose, which is performed in grad_output.update_usage
-                if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling():
-                    ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
-                else:
-                    ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
+                rowwise_usage = True
+                columnwise_usage = True
+                if ctx.ub_overlap_ag and isinstance(
+                    ctx.grad_output_quantizer,
+                    (Float8Quantizer, Float8CurrentScalingQuantizer),
+                ):
+                    # If data is in FP8 and communication is handled
+                    # with Userbuffers, we compute FP8 transposes
+                    # manually
+                    columnwise_usage = False
+                ctx.grad_output_quantizer.set_usage(
+                    rowwise=rowwise_usage,
+                    columnwise=columnwise_usage,
+                )
+
+            # Prepare grad output tensor
+            # Note: Cast to expected dtype and perform tensor-parallel communication
             nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
             (
                 grad_output,
@@ -575,15 +591,19 @@ def backward(
             )
             nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
 
-            # Prepare GEMM input
-            # Note: Perform tensor-parallel communication if needed
+            # Launch tensor-parallel communication for LayerNorm out tensor
             ln_out_total = None
             ln_out_total_work = None
             if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
                 quantizer = None
                 if ctx.fp8:
                     quantizer = ctx.input_quantizer
-                    quantizer.set_usage(rowwise=False, columnwise=True)
+                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                        # If data is in FP8, we compute FP8 transposes manually
+                        quantizer.set_usage(rowwise=True, columnwise=False)
+                    else:
+                        # wgrad GEMM requires input with column-wise usage
+                        quantizer.set_usage(rowwise=False, columnwise=True)
                 nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
                 ln_out_total, ln_out_total_work = gather_along_first_dim(
                     ln_out,
@@ -652,6 +672,8 @@ def backward(
             # Compute grad weight tensor
             wgrad = None
             if ctx.requires_wgrad:
+
+                # Synchronize tensor-parallel communication for input tensor
                 if ctx.ub_bulk_dgrad:
                     ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
                     if ctx.fp8:
@@ -665,18 +687,25 @@ def backward(
                         # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                         # have a valid transpose.
                         ln_out_total._create_transpose()
+                if ln_out_total_work is not None:
+                    ln_out_total_work.wait()
+                    ln_out_total_work = None
 
-                else:
-                    if ln_out_total_work is not None:
-                        # Synchronize tensor-parallel communication
-                        ln_out_total_work.wait()
-                        ln_out_total_work = None
-
+                # Make sure GEMM inputs have required data
+                if isinstance(ln_out_total, QuantizedTensor):
+                    ln_out_total.update_usage(columnwise_usage=True)
                 if isinstance(grad_output, QuantizedTensor):
-                    # This is a no-op if platform supports non-TN FP8 GEMM or the transpose
-                    # already exists.
-                    grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
+                    grad_output.update_usage(columnwise_usage=True)
+
+                # Figure out whether to use split accumulator
+                use_split_accumulator = _2X_ACC_WGRAD
+                if ctx.fp8:
+                    recipe = ctx.fp8_recipe
+                    if hasattr(recipe, "fp8_gemm_wgrad"):
+                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
 
+                # Output buffer for overlapping grad input
+                # reduce-scatter with wgrad GEMM
                 if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                     rs_out = torch.empty(
                         dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
@@ -685,14 +714,6 @@ def backward(
                 # wgrad GEMM
                 # Note: Fuse with bgrad computation if needed
                 nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-                wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
-                if ctx.fp8:
-                    recipe = ctx.fp8_recipe
-                    if hasattr(recipe, "fp8_gemm_wgrad"):
-                        wgrad_gemm_use_split_accumulator = (
-                            recipe.fp8_gemm_wgrad.use_split_accumulator
-                        )
-
                 wgrad, grad_bias_, *_, rs_out = general_gemm(
                     ln_out_total,
                     grad_output,
@@ -704,7 +725,7 @@ def backward(
                     ),
                     bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                     out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                    use_split_accumulator=wgrad_gemm_use_split_accumulator,
+                    use_split_accumulator=use_split_accumulator,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     ub=ub_obj_wgrad,
                     ub_type=ub_type_wgrad,
@@ -728,7 +749,7 @@ def backward(
                     # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                     clear_tensor_data(ln_out_total)
 
-            # Synchronize tensor parallel communication
+            # Make sure all tensor-parallel communication is finished
             if ln_out_total_work is not None:
                 ln_out_total_work.wait()
                 ln_out_total_work = None
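
The usage-flag logic this patch introduces can be exercised on its own. The sketch below is a minimal, self-contained illustration: `_StubQuantizer`, the stand-in `Float8Quantizer`/`Float8CurrentScalingQuantizer`/`MXFP8Quantizer` classes, and `configure_grad_output_quantizer` are hypothetical placeholders rather than Transformer Engine APIs; only the branch condition mirrors the diff, namely that column-wise usage is skipped only when Userbuffers all-gather overlap is active and the quantizer is a per-tensor FP8 quantizer, since the FP8 transpose is then computed manually later.

```python
# Standalone sketch of the usage-flag selection in this patch, using stand-in
# classes instead of the real transformer_engine quantizers.
from dataclasses import dataclass


@dataclass
class _StubQuantizer:
    """Minimal stand-in for a quantizer with row-wise/column-wise usage flags."""

    rowwise: bool = True
    columnwise: bool = True

    def set_usage(self, *, rowwise: bool, columnwise: bool) -> None:
        self.rowwise = rowwise
        self.columnwise = columnwise


class Float8Quantizer(_StubQuantizer):
    """Stand-in for per-tensor delayed-scaling FP8 quantization."""


class Float8CurrentScalingQuantizer(_StubQuantizer):
    """Stand-in for per-tensor current-scaling FP8 quantization."""


class MXFP8Quantizer(_StubQuantizer):
    """Stand-in for block-scaled MXFP8 quantization (no manual transpose path)."""


def configure_grad_output_quantizer(quantizer: _StubQuantizer, ub_overlap_ag: bool) -> None:
    """Mirror the patch: dgrad GEMM needs row-wise data and wgrad GEMM needs
    column-wise data, but when the grad output is per-tensor FP8 and gathered
    through Userbuffers, the transpose is built manually later, so column-wise
    usage is not requested at quantization time."""
    rowwise_usage = True
    columnwise_usage = True
    if ub_overlap_ag and isinstance(
        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
    ):
        columnwise_usage = False
    quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)


if __name__ == "__main__":
    fp8 = Float8Quantizer()
    mxfp8 = MXFP8Quantizer()
    configure_grad_output_quantizer(fp8, ub_overlap_ag=True)
    configure_grad_output_quantizer(mxfp8, ub_overlap_ag=True)
    # Per-tensor FP8 drops column-wise usage; MXFP8 keeps both usages.
    print(fp8.columnwise, mxfp8.columnwise)  # False True
```

The isinstance check replaces the old `fp8_recipe.float8_per_tensor_scaling()` test, so block-scaled quantizers such as MXFP8 keep both usages and produce their column-wise data at quantization time, as the diff's own comments indicate.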