This issue is related to pytorch/torchtitan#2160. While looking into it, I think we need to update MXTensor to support MXFP8 all-gather for the linear layers, which should improve end-to-end model performance. Specifically, I think we need to add qdata_dim0 and scale_dim0 to MXTensor, as below (we can update the naming per feedback):
```python
class MXTensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "qdata_dim0", "scale_dim0"]
    tensor_attribute_names = [
        "_elem_dtype",
        "block_size",
        "_orig_dtype",
        "kernel_preference",
        "act_quant_kwargs",
        "_is_swizzled_scales",
    ]

    # in the constructor, the new fields default to None when not populated:
    #   self.qdata_dim0 = qdata_dim0  # quantized data, scaled along dim 0
    #   self.scale_dim0 = scale_dim0  # scales along dim 0
```
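For illustration, here is a minimal sketch of how the extra fields might be populated at quantization time. The MXTensor.to_mx call signature, the import path, and attaching the extra fields after construction are all assumptions; the real construction path in torchao may differ:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor  # assumed import path

def quantize_both_dims(x_hp: torch.Tensor, elem_dtype, block_size: int) -> MXTensor:
    # Usual quantization: scales computed along the last dim, i.e. the
    # contraction dim of the forward matmul.
    x_mx = MXTensor.to_mx(x_hp, elem_dtype, block_size)

    # Also quantize the transposed view, so the backward pass has data scaled
    # along dim 0 without keeping the high-precision input around.
    x_t_mx = MXTensor.to_mx(x_hp.t().contiguous(), elem_dtype, block_size)

    # Attach the second copy to the proposed new fields (hypothetical; the real
    # constructor would presumably take these as arguments instead).
    x_mx.qdata_dim0 = x_t_mx.qdata
    x_mx.scale_dim0 = x_t_mx.scale
    return x_mx
```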
We need something like the above because MXLinear (which inherits from nn.Linear) expects only one input in the forward, and hence the backward also receives only one grad_output. But to perform the MXFP8 all-gather we need the following flow:
Quantize Input data -> MXFP8 All Gather -> Matmul
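As a rough sketch of that flow: the torch.distributed collectives below are real APIs, but rebuild_mx_tensor is a hypothetical helper standing in for whatever MXTensor re-construction path is used, and the dispatch of torch.mm to the MX matmul kernel is assumed:

```python
import torch
import torch.distributed as dist

def mxfp8_all_gather_matmul(x_mx, w_mx, world_size: int, group=None):
    # All-gather the low-precision qdata and the e8m0 scales instead of the
    # bf16 input, cutting the communication volume roughly in half.
    qdata_out = x_mx.qdata.new_empty(world_size * x_mx.qdata.shape[0], *x_mx.qdata.shape[1:])
    scale_out = x_mx.scale.new_empty(world_size * x_mx.scale.shape[0], *x_mx.scale.shape[1:])
    dist.all_gather_into_tensor(qdata_out, x_mx.qdata, group=group)
    dist.all_gather_into_tensor(scale_out, x_mx.scale, group=group)

    # Re-wrap the gathered buffers as an MXTensor and run the MX matmul.
    x_gathered = rebuild_mx_tensor(qdata_out, scale_out, like=x_mx)  # hypothetical helper
    return torch.mm(x_gathered, w_mx.t())
```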
For the backward pass to be correct, the input also needs to be quantized along the other dimension: the weight-gradient matmul, dW = grad_output^T @ input, contracts over dim 0, so it needs the input either quantized along dim 0 or available in the original high precision. Saving the original input is wasteful in terms of memory, so instead we can save the input quantized along the other dim. If we make this an attribute of MXTensor, we no longer need to pass two inputs, and the operations above can be performed cleanly. The same logic applies to grad_output: we can quantize the grad along both dim 0 and dim 1, store both in the MXTensor, and in the backward pass transpose the dim 0 copy and use it for the matmul directly. I was able to do a POC and get the MXFP8 all-gather working in the forward and backward of an MXLinear layer (more changes were needed as well, but this was the highlight).
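To make the backward data flow concrete, here is a minimal sketch of how the two quantized copies would be consumed. mx_view_dim0 is a hypothetical helper that wraps the qdata_dim0/scale_dim0 fields as an MXTensor, and torch.mm is assumed to dispatch to the MX matmul for these operands:

```python
import torch

def mx_linear_backward_sketch(grad_out_mx, x_gathered_mx, w_for_dgrad):
    # dgrad: dX = dY @ W contracts over out_features, i.e. dim 1 of grad_output,
    # so it consumes the default qdata/scale (quantized along dim 1).
    # w_for_dgrad is assumed to be quantized appropriately for this contraction.
    grad_x = torch.mm(grad_out_mx, w_for_dgrad)

    # wgrad: dW = dY^T @ X contracts over the token dim, i.e. dim 0 of both
    # grad_output and the gathered input, so it consumes the proposed
    # qdata_dim0/scale_dim0 copies of each, with the grad copy transposed.
    grad_w = torch.mm(mx_view_dim0(grad_out_mx).t(), mx_view_dim0(x_gathered_mx))
    return grad_x, grad_w
```

This is the piece that lets us avoid saving the high-precision input for the backward pass.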
I took a look at NVIDIA Transformer Engine and see a similar design there as well: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/tensor/mxfp8_tensor.py#L226-L229
I'd like to know your thoughts, @vkuzo.