
[not for land] [mxfp8 training] fix TP bug#3985

Open
danielvegamyhre wants to merge 2 commits into main from tpmarch3

Conversation

@danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Mar 4, 2026

No description provided.

@pytorch-bot

pytorch-bot bot commented Mar 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3985

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures

As of commit 3dd03f5 with merge base b8708a2:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 4, 2026
@pianpwk

pianpwk commented Mar 4, 2026

On my side the [LINEAR rank=*] logs are only triggered from the non-parallelized model; this patch gets SQNR from 23 -> 50:

```diff
+++ b/torchao/prototype/moe_training/tensor.py
@@ -21,6 +21,7 @@ from torchao.prototype.moe_training.config import (
 )
 from torchao.prototype.moe_training.utils import _quantize_then_scaled_grouped_mm
 from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_then_scaled_mm
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.utils import TorchAOBaseTensor
 
 aten = torch.ops.aten
@@ -357,6 +359,37 @@ class MXFP8TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor):
             with torch._C.DisableTorchFunctionSubclass():
                 return func(*args, **kwargs)
 
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs={}):
+        # Intercept aten.mm.default to apply MXFP8 quantization.
+        # This is needed because DTensor decomposes F.linear (a CompositeImplicitAutograd op)
+        # into aten.t + aten.mm, bypassing our __torch_function__ override for "linear".
+        # Without this, the TP/SP model would compute with regular BF16 mm.
+        if func == torch.ops.aten.mm.default:
+            A, B = args[0], args[1]
+
+            if not isinstance(A, cls) and isinstance(B, cls):
+                config = B.config
+
+                if isinstance(config, MXFP8TrainingOpConfig):
+                    input_hp = A
+                    # B._data is the transposed weight (from aten.t in linear decomposition).
+                    # _to_mxfp8_then_scaled_mm expects weight in [out, in] layout and
+                    # internally computes input @ weight.t(), so we pass B._data.t()
+                    # to recover the original weight layout:
+                    #   input @ (B._data.t()).t() = input @ B._data
+                    # which matches the aten.mm(input, weight_t) semantics.
+                    weight_hp = B._data.t().contiguous()
+                    return _to_mxfp8_then_scaled_mm(
+                        input_hp,
+                        weight_hp,
+                        kernel_preference=config.kernel_preference,
+                        scale_calculation_mode=config.scale_calculation_mode,
+                        wgrad_with_hp=config.wgrad_with_hp,
+                    )
+
+        return super().__torch_dispatch__(func, types, args, kwargs)
```
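For context on the SQNR numbers being discussed: SQNR (signal-to-quantization-noise ratio, in dB) is conventionally computed as 10 * log10 of signal power over error power, so 23 -> 50 is a large accuracy improvement. A minimal dependency-free sketch (the helper name `sqnr_db` and the toy values are illustrative, not from torchao):

```python
import math

def sqnr_db(ref, approx):
    # SQNR in dB: 10 * log10(signal_power / noise_power),
    # where noise is the elementwise error vs. the reference.
    signal = sum(r * r for r in ref)
    noise = sum((r - a) ** 2 for r, a in zip(ref, approx))
    return 10 * math.log10(signal / noise)

ref = [1.0, -2.0, 3.0, -4.0]
quantized = [1.01, -1.99, 3.02, -3.98]
print(round(sqnr_db(ref, quantized), 1))  # small error -> high SQNR (~44.8 dB)
```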

@danielvegamyhre
Contributor Author

> On my side the [LINEAR rank=*] logs are only triggered from the non-parallelized model; this patch gets SQNR from 23 -> 50: [...]

thanks for looking at this @pianpwk, knowing that linear is decomposed into aten.t + aten.mm is super useful. quick question: i thought __torch_dispatch__ runs below autograd, so we can't run autograd functions there, otherwise backward won't run properly? is that not the case? this is why i have the autograd functions in __torch_function__ instead
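One way to see the autograd/dispatch ordering empirically is with a `TorchDispatchMode` that logs every aten op: since autograd runs above the dispatch level, the backward mm ops also show up in the log. A rough sketch, assuming a recent PyTorch (the `OpLogger` name is made up; this is not code from the PR):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    """Records every aten op seen at the __torch_dispatch__ level."""
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))

x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)
with OpLogger() as logger:
    out = torch.mm(x, w)
    out.sum().backward()

# Both the forward mm and the backward mm ops (for grad_x and grad_w)
# appear here, since dispatch sits below autograd.
mm_ops = [op for op in logger.ops if "mm" in op]
print(mm_ops)
```

This is consistent with the observation in the thread: by the time `__torch_dispatch__` sees an op, autograd has already recorded it, which is why autograd.Function overrides are usually installed at the `__torch_function__` level instead.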

@danielvegamyhre
Contributor Author

danielvegamyhre commented Mar 4, 2026

@pianpwk

> it needs to intercept aten::mm at the torch_dispatch level instead.

that makes sense, but in the past when i tried to use autograd functions in __torch_dispatch__, i found only the forward pass ran, never the backward pass, which is problematic for this use case since we need to control the backward pass as well. my understanding was that autograd is not captured at the __torch_dispatch__ level, am i mistaken on that?

when DTensor decomposes linear into aten.t + aten.mm, does it skip torch_function entirely and go straight to dispatch? is there no way to intercept in torch_function?

@pianpwk

pianpwk commented Mar 4, 2026

sorry, I think I was also partly mistaken. I'm not 100% sure what causes the (linear -> t + view + mm) decomposition for the DTensor + MXFP8 composition case, but you should also be able to intercept mm for it at the torch_function level as well.

Let me dig a bit more...

@danielvegamyhre
Contributor Author

> sorry, I think I was also partly mistaken. I don't think DTensor CIA handling is why the (linear -> t + view + mm) decomposition happens anymore, and you should also be able to intercept mm at the torch_function level as well.
>
> Let me dig a bit more...

ok thanks again for your help - for what it's worth, i've been trying to figure out how to intercept mm at the torch_dispatch level as well for the TP case, but i don't see any meaningful func names, just __get__ at the function level followed by mm.default at the dispatch level:

```
TP case
[DISPATCH rank=0] aten.t.default args=['Wrapper(torch.Size([64, 64]), mean=-0.001793)']
func __get__
[DISPATCH rank=0] aten.mm.default args=['Tensor(torch.Size([256, 64]), mean=0.496094)', 'Wrapper(torch.Size([64, 64]), mean=-0.001793)']
not preserving subclass aten.mm.default
[DISPATCH rank=0] aten.t.default args=['Wrapper(torch.Size([64, 64]), mean=-0.000147)']
func __get__
[DISPATCH rank=0] aten.mm.default args=['Tensor(torch.Size([256, 64]), mean=0.496094)', 'Wrapper(torch.Size([64, 64]), mean=-0.000147)']
not preserving subclass aten.mm.default
[DISPATCH rank=0] aten.t.default args=['Wrapper(torch.Size([64, 64]), mean=0.000112)']
func __get__
[DISPATCH rank=0] aten.mm.default args=['Tensor(torch.Size([256, 64]), mean=-0.000790)', 'Wrapper(torch.Size([64, 64]), mean=0.000112)']
not preserving subclass aten.mm.default
```

@pianpwk

pianpwk commented Mar 4, 2026

> but i don't see any meaningful func names, just __get__ at function level followed by mm.default at dispatch level:

my bad, this only triggered because I had a custom dispatch mode on (DebugMode), otherwise it's at dispatch level. I get your point about autograd though, let me ask around. I tried reversing (parallelize before quantize) but that hit other composability issues.

@danielvegamyhre
Contributor Author

> but i don't see any meaningful func names, just __get__ at function level followed by mm.default at dispatch level:
>
> my bad, this only triggered because I had a custom dispatch mode on (DebugMode), otherwise it's at dispatch level. I get your point about autograd though, let me ask around. I tried reversing (parallelize before quantize) but that hit other composability issues.

sounds good, thanks again - i will keep debugging on my side as well

@pianpwk

pianpwk commented Mar 5, 2026

@danielvegamyhre I vibecoded this reordering fix which seems to work? #4010

@danielvegamyhre
Contributor Author

> @danielvegamyhre I vibecoded this reordering fix which seems to work? #4010

@pianpwk per our discussion on that PR and elsewhere, do you know of a way to intercept the linear op when the wrapping is DTensor(MXFP8TrainingTensor(..))? that is the fundamental issue i am trying to solve; from logging, it seems to be decomposed into:

  • __get__ (not sure what this is)
  • t
  • mm
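As a sanity check on the t + mm decomposition itself: F.linear with a weight in [out, in] layout computes input @ weight.T, which is exactly mm(input, t(weight)), so intercepting aten.mm with the transposed weight argument is algebraically equivalent to intercepting linear. The double transpose in the patch comment can be verified with a dependency-free sketch (helper names are made up):

```python
def transpose(m):
    # swap rows and columns of a nested-list matrix
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    # naive matrix multiply for nested lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

x = [[1.0, 2.0], [3.0, 4.0]]   # activations
w = [[5.0, 6.0], [7.0, 8.0]]   # original weight, [out, in] layout

# What aten.mm receives after the linear -> t + mm decomposition:
w_t = transpose(w)
mm_out = matmul(x, w_t)        # aten.mm(x, w_t)

# The patch recovers the [out, in] layout by transposing again (B._data.t());
# a scaled-mm helper that re-transposes internally then computes
#   x @ (w_t.T).T == x @ w_t, matching the aten.mm semantics.
recovered = transpose(w_t)
scaled_mm_style = matmul(x, transpose(recovered))

assert scaled_mm_style == mm_out
print(mm_out)  # [[17.0, 23.0], [39.0, 53.0]]
```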

Contributor

@andrewor14 andrewor14 left a comment


stamping to unblock CI, please address the comments before landing

```python
@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
    # grouped_mm op override
    print("[TORCH_FUNCTION]", func.__name__)
```
Contributor


remove these before landing?

```python
)
assert data_hp.is_contiguous(), "unsupported"
if not data_hp.is_contiguous():
    assert data_hp.is_contiguous(), "unsupported"
```
Contributor


a bit confused by this, if it's not contiguous it would fail like before, so is there a reason behind this change?

Contributor Author

@danielvegamyhre danielvegamyhre Mar 11, 2026


@andrewor14 sorry, i linked the wrong PR - this is not the one that will address the test failures. this is a WIP draft for an issue we are still trying to find a proper solution for, so please disregard this PR.

```python
torch.ops.aten.transpose.int,
torch.ops.aten.t.default,
# required for TP - scatter_ is used to distribute weights
torch.ops.c10d.scatter_.default,
```
Contributor


so this is the real fix right? Do we need the other ops you commented out?

@danielvegamyhre danielvegamyhre changed the title [mxfp8 training] fix TP bug [not for land] [mxfp8 training] fix TP bug Mar 11, 2026