[not for land] [mxfp8 training] fix TP bug #3985
danielvegamyhre wants to merge 2 commits into main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3985
Note: Links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures as of commit 3dd03f5 with merge base b8708a2.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
On my side the `[LINEAR rank=*]` logs are only triggered from the non-parallelized model; this patch gets SQNR from 23 -> 50:

```diff
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -21,6 +21,7 @@ from torchao.prototype.moe_training.config import (
 )
 from torchao.prototype.moe_training.utils import _quantize_then_scaled_grouped_mm
 from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_then_scaled_mm
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.utils import TorchAOBaseTensor

 aten = torch.ops.aten
@@ -357,6 +359,37 @@ class MXFP8TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor):
         with torch._C.DisableTorchFunctionSubclass():
             return func(*args, **kwargs)

+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs={}):
+        # Intercept aten.mm.default to apply MXFP8 quantization.
+        # This is needed because DTensor decomposes F.linear (a CompositeImplicitAutograd op)
+        # into aten.t + aten.mm, bypassing our __torch_function__ override for "linear".
+        # Without this, the TP/SP model would compute with regular BF16 mm.
+        if func == torch.ops.aten.mm.default:
+            A, B = args[0], args[1]
+
+            if not isinstance(A, cls) and isinstance(B, cls):
+                config = B.config
+
+                if isinstance(config, MXFP8TrainingOpConfig):
+                    input_hp = A
+                    # B._data is the transposed weight (from aten.t in linear decomposition).
+                    # _to_mxfp8_then_scaled_mm expects weight in [out, in] layout and
+                    # internally computes input @ weight.t(), so we pass B._data.t()
+                    # to recover the original weight layout:
+                    #   input @ (B._data.t()).t() = input @ B._data
+                    # which matches the aten.mm(input, weight_t) semantics.
+                    weight_hp = B._data.t().contiguous()
+                    return _to_mxfp8_then_scaled_mm(
+                        input_hp,
+                        weight_hp,
+                        kernel_preference=config.kernel_preference,
+                        scale_calculation_mode=config.scale_calculation_mode,
+                        wgrad_with_hp=config.wgrad_with_hp,
+                    )
+
+        return super().__torch_dispatch__(func, types, args, kwargs)
```
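An aside for readers: the layout algebra in that comment can be checked numerically with plain tensors. Nothing MXFP8-specific here; the scaled-mm helper is stood in for by a plain `input @ w.t()`, which is the layout contract the comment describes.

```python
import torch

inp = torch.randn(2, 3)
weight = torch.randn(4, 3)      # original weight, [out, in] layout

# linear decomposition: aten.t on the weight, then aten.mm
weight_t = weight.t()           # what B._data holds inside the dispatch hook
ref = inp @ weight_t            # aten.mm(input, weight_t) semantics

# the hook recovers the original [out, in] layout before handing off
recovered = weight_t.t().contiguous()

# a helper that internally computes input @ w.t() then reproduces ref:
#   inp @ (weight_t.t()).t() == inp @ weight_t
out = inp @ recovered.t()

assert torch.allclose(ref, out, atol=1e-5)
```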
thanks for looking at this @pianpwk, knowing that linear is decomposed into aten.t + aten.mm is super useful. quick question: i thought __torch_dispatch__ runs below autograd, so we can't run autograd functions there, otherwise backward won't run properly? is that not the case? this is why i have the autograd functions in __torch_function__ instead
that makes sense, but in the past when i tried to use autograd functions in __torch_dispatch__, i found only the forward pass ran, never the backward pass, which is problematic for this use case since we need to control the backward pass as well. my understanding was that autograd is not captured at the __torch_dispatch__ level, am i mistaken on that? when DTensor decomposes linear into aten.t + aten.mm, does it skip __torch_function__ entirely and go straight to dispatch? is there no way to intercept in __torch_function__?
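An aside: the "dispatch is below autograd" point can be observed directly with a small `TorchDispatchMode` trace. The class name here is illustrative and not from the PR; the idea is just that the backward-pass aten ops also reach the dispatch level, i.e. autograd has already done its recording above this layer, which is consistent with the observation that an autograd.Function created inside `__torch_dispatch__` never has its backward fire.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    """Illustrative mode: records every aten op that reaches dispatch."""
    def __init__(self):
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))

x = torch.randn(2, 3, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)
logger = OpLogger()
with logger:
    out = torch.nn.functional.linear(x, w)  # decomposed: t + mm reach the mode
    out.sum().backward()                    # backward mms reach the mode too

# both the forward mm and the backward matmuls appear in logger.ops,
# even though autograd itself is handled above the dispatch level
```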
sorry, I think I was also partly mistaken. I'm not 100% sure what causes the (linear -> t + view + mm) decomposition for the DTensor + MXFP8 composition case, but you should also be able to intercept mm for it at the torch_function level as well. Let me dig a bit more...
ok thanks again for your help - for what it's worth i've been trying to figure out how to intercept mm at torch dispatch level as well for TP case but i don't see any meaningful func names, just
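For what it's worth, a minimal wrapper subclass, independent of any MXFP8 code, shows which func names reach `__torch_dispatch__` when `F.linear` is called on it. Everything below is illustrative (the class and its bookkeeping are invented for the demo): since `F.linear` is CompositeImplicitAutograd, the subclass never sees a "linear" func, only the decomposition.

```python
import torch
from torch.utils._pytree import tree_map

class LoggingWeight(torch.Tensor):
    """Illustrative wrapper subclass (no quantization) that records which
    aten ops reach __torch_dispatch__ when used as a linear weight."""
    seen = []

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        cls.seen.append(str(func))
        unwrap = lambda t: t._data if isinstance(t, cls) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # re-wrap outputs so follow-up ops (e.g. the mm after the t)
        # are also routed through this subclass
        wrap = lambda t: cls(t) if isinstance(t, torch.Tensor) else t
        return tree_map(wrap, out)

x = torch.randn(2, 3)
w = LoggingWeight(torch.randn(4, 3))
y = torch.nn.functional.linear(x, w)
# LoggingWeight.seen now holds the decomposition, typically
# aten.t.default followed by aten.mm.default - never "linear" itself
```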
my bad, this only triggered because I had a custom dispatch mode on (DebugMode), otherwise it's at dispatch level. I get your point about autograd though, let me ask around. I tried reversing (parallelize before quantize) but that hit other composability issues. |
sounds good, thanks again - i will keep debugging on my side as well |
@danielvegamyhre I vibecoded this reordering fix which seems to work? #4010 |
@pianpwk per our discussion on that PR and elsewhere, do you know of a way to intercept the linear op when the wrapping is DTensor(MXFP8TrainingTensor(..))? that is the fundamental issue i am trying to solve, from logging it seems to be decomposed into
andrewor14 left a comment:
stamping to unblock CI, please address the comments before landing
```python
@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
    # grouped_mm op override
    print("[TORCH_FUNCTION]", func.__name__)
```
remove these before landing?
```diff
 )
-assert data_hp.is_contiguous(), "unsupported"
+if not data_hp.is_contiguous():
+    assert data_hp.is_contiguous(), "unsupported"
```
a bit confused by this, if it's not contiguous it would fail like before, so is there a reason behind this change?
@andrewor14 sorry i linked the wrong PR, this is not the one that will address test failures, this is a WIP draft for an issue we are still trying to find a proper solution for - please disregard this PR.
```python
torch.ops.aten.transpose.int,
torch.ops.aten.t.default,
# required for TP - scatter_ is used to distribute weights
torch.ops.c10d.scatter_.default,
```
so this is the real fix right? Do we need the other ops you commented out?