
[TP] reorder MXFP8 wrapper over DTensor#4010

Open
pianpwk wants to merge 8 commits intopytorch:mainfrom
pianpwk:mxfp8-tp-fix

Conversation

@pianpwk

@pianpwk pianpwk commented Mar 5, 2026

For pytorch/pytorch#177059, #3985

The original attempt to handle TP + MXFP8 wrapped DTensor over the MXFP8 subclass. The MXFP8 subclass intends to capture ops at the __torch_function__ level and use custom autograd functions to control fwd/bwd behavior. Because DTensor has no fwd/bwd coupling, and because it CIA-decomposes aten::linear, this ordering does not work: the MXFP8 tensor never sees aten::linear at the __torch_function__ level.

This PR reverses the order to MXFP8(DTensor), allowing aten::linear interception and fwd/bwd control. Relies on pytorch/pytorch#177234 landing in pytorch.

I understand other dtypes still need reordering?

More details in discussion of pytorch/pytorch#177059
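As a toy illustration of the ordering problem (plain Python, not the actual torchao/DTensor machinery): in subclass dispatch, only the outermost wrapper's hook sees the high-level call, so whichever tensor sits on the outside "wins" at the __torch_function__ level.

```python
# Toy model of outermost-wrapper dispatch (plain Python, not real torch code).
# Mirrors how __torch_function__ resolution hands the call to the outermost
# tensor subclass: the inner wrapper never sees F.linear at all.
class Wrapper:
    def __init__(self, inner, name):
        self.inner = inner
        self.name = name

    def handle(self, op):
        # The outermost wrapper intercepts; the inner wrapper is skipped.
        return f"{self.name} intercepted {op}"

def linear(weight):
    # Stand-in for F.linear dispatching to the outermost subclass.
    return weight.handle("aten::linear")

# Old order, DTensor(MXFP8(...)): DTensor dispatches first and
# CIA-decomposes linear before MXFP8 can apply its autograd functions.
old = Wrapper(Wrapper("w", "MXFP8"), "DTensor")
# New order, MXFP8(DTensor(...)): MXFP8 intercepts linear as intended.
new = Wrapper(Wrapper("w", "DTensor"), "MXFP8")
print(linear(old))  # DTensor intercepted aten::linear
print(linear(new))  # MXFP8 intercepted aten::linear
```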

@pytorch-bot

pytorch-bot bot commented Mar 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4010

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f31cd57 with merge base 1d75a07:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 5, 2026
@danielvegamyhre
Contributor

thanks for investigating this further @pianpwk! i added a couple notes to assist

@pianpwk
Author

pianpwk commented Mar 10, 2026

@danielvegamyhre updated the PR to cover more spots where we should reorder. I didn't review carefully, but at a high level it should make more sense?

Also had to override a DTensor sharding rule for scaled_mm - I'll upstream that to pytorch instead, if it turns out to be valid.

@danielvegamyhre
Copy link
Contributor

nice @pianpwk, qq, why did you need to register a custom sharding strategy for scaled_mm, does the one in core have a bug?

run_check=False,
shape=data_lp.size(),
stride=data_lp.stride(),
)
Author

consequence of reversing order?

Contributor

makes sense, this would rewrap as "DTensor(MXTensor(...))", which is the opposite order of what we are doing now. nice that all this can be removed now, cleaner

@pianpwk pianpwk changed the title fix tensor parallelism by reordering subclass wrapping [TP] reorder MXFP8 wrapper over DTensor Mar 12, 2026
@pianpwk pianpwk marked this pull request as ready for review March 12, 2026 06:20
@danielvegamyhre danielvegamyhre added this to the MXFP8 Training milestone Mar 12, 2026
@pianpwk pianpwk added the module: training quantize_ api training flow label Mar 13, 2026
pianpwk added a commit to pytorch/pytorch that referenced this pull request Mar 13, 2026
also for pytorch/ao#4010, adds custom handler to check local tensor

[ghstack-poisoned]
pianpwk added a commit to pytorch/pytorch that referenced this pull request Mar 13, 2026
also for pytorch/ao#4010, adds custom handler to check local tensor

[ghstack-poisoned]
pianpwk added a commit to pytorch/pytorch that referenced this pull request Mar 13, 2026
For pytorch/ao#4010, the existing strategy incorrectly copies scaled_mm input strategies (2d) onto the scale tensor (1d)

[ghstack-poisoned]
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Mar 14, 2026
For pytorch/ao#4010, the existing strategy incorrectly copies scaled_mm input strategies (2d) onto the scale tensor (1d)
Pull Request resolved: #177234
Approved by: https://github.com/danielvegamyhre
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Mar 14, 2026
also for pytorch/ao#4010, adds custom handler to check local tensor
Pull Request resolved: #177235
Approved by: https://github.com/Skylion007, https://github.com/wconstab
ghstack dependencies: #177234
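To see why copying a 2d input's placement onto the 1d scale was wrong, here is a minimal shape check (toy code, not the actual DTensor sharding-rule API): a Shard(dim) placement is only valid when `dim` is an actual dimension of the tensor.

```python
# Toy check of the bug pytorch/pytorch#177234 fixed: Shard(dim) only makes
# sense if `dim` exists on the tensor. Copying the 2d scaled_mm input's
# Shard(1) onto a 1d scale tensor therefore produces an invalid placement.
def shard_is_valid(tensor_ndim, shard_dim):
    return 0 <= shard_dim < tensor_ndim

print(shard_is_valid(2, 1))  # True: Shard(1) is fine for a 2d scaled_mm input
print(shard_is_valid(1, 1))  # False: Shard(1) is invalid for a 1d scale tensor
```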
Contributor

@danielvegamyhre danielvegamyhre left a comment

nice, looks much cleaner now with that dtensor 1d scale sharding fix landed in core! couple minor comments

del t
tests.append(_test_mxfp8_mlp_tensor_parallelism_auto)
except Exception:
print("Skipping auto test: mxfp8_quantize CUDA kernel not available")
Contributor

if _mxfp8_cuda_kernels_available let's just append the test case without doing the kernel dispatch test


# For MXFP8: parallelize first, then quantize.
# This puts MXFP8 wrapper on top of DTensor so __torch_function__
# intercepts F.linear before DTensor can trigger premature all-gathers.
Contributor

@danielvegamyhre danielvegamyhre Mar 16, 2026

@vkuzo this wrapping order change to MXFP8WeightWrapperTensor(DTensor(...)) is a pretty fundamental one, if you'd like to review as well. see the comment above; it is working cleanly now after @pianpwk landed a fix to DTensor sharding rules for 1d/flattened scale factors for torch._scaled_mm: pytorch/pytorch#177234

pianpwk added 8 commits March 19, 2026 00:11
When DTensor wraps MXFP8 (quantize first, then parallelize), DTensor
dispatches first on F.linear and performs premature Shard→Replicate
all-gather before MXFP8's __torch_function__ can intercept, causing
both ranks to see identical full weights and producing wrong numerics.

Fix: reverse the wrapping order for MXFP8 so MXFP8 sits on top of
DTensor (parallelize first, then quantize). MXFP8's __torch_function__
intercepts F.linear first, unwraps to get the DTensor via a
differentiable unwrap_weight() helper, and DTensor handles sharding
at the aten op level.

Changes:
- tensor.py: add scatter_ to preserved ops (TP weight distribution),
  fix pin_memory for DTensor, narrow linear override to func name
  "linear" only, use unwrap_weight(B) in both grouped_mm and linear
  paths, add _UnwrapWeight autograd function
- utils.py: transpose DTensor placements in _to_mxfp8_dim1_kernel_wrapper
  to match transposed local data
- dtensor_utils.py: make parallelize/quantize order conditional on
  config type (MXFP8: parallelize first; Float8: quantize first),
  use SQNR-based assertions for MXFP8, bf16 model/inputs
- test_mx_dtensor.py: update to MXFP8 (was FP4), split into emulated
  and auto tests with CUDA kernel availability guard
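A minimal sketch of the differentiable unwrap pattern the commit message describes (the `_UnwrapWeight` name comes from the message; this stand-in returns a clone of a plain tensor rather than the inner DTensor, since the point is only that gradients flow through the unwrap):

```python
import torch

# Sketch of a differentiable unwrap helper in the spirit of _UnwrapWeight.
# The real helper would return the inner DTensor of the MXFP8 wrapper; this
# stand-in clones a plain tensor so the identity-backward behavior is visible.
class UnwrapWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight):
        # Real version: return the wrapped (inner) tensor here.
        return weight.clone()

    @staticmethod
    def backward(ctx, grad_out):
        # Identity backward: gradient flows straight to the wrapped weight.
        return grad_out

w = torch.randn(4, 4, requires_grad=True)
UnwrapWeight.apply(w).sum().backward()
print(torch.allclose(w.grad, torch.ones(4, 4)))  # True
```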

to_blocked,
in_placements=(t_placements,),
out_placements=t_placements,
)(t)
Contributor

to confirm my understanding, local_map just runs the function (to_blocked) on each local shard as if it were a plain tensor, and then rewraps the output in a dtensor according to out_placements right

Author

Yep!
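A rough conceptual stand-in for that local_map behavior (toy code operating on plain lists of shards, not real DTensors; the actual API is torch.distributed.tensor.experimental.local_map):

```python
# Conceptual sketch of local_map: run `fn` on each local shard as if it were
# a plain tensor, then rewrap the outputs according to out_placements. This
# toy version uses dicts of shard lists instead of DTensors on a device mesh.
def toy_local_map(fn, out_placements):
    def wrapped(dtensor_like):
        local_outs = [fn(shard) for shard in dtensor_like["shards"]]
        return {"shards": local_outs, "placements": out_placements}
    return wrapped

blocked = toy_local_map(lambda s: s * 2, out_placements=("Shard(0)",))
out = blocked({"shards": [1, 2], "placements": ("Shard(0)",)})
print(out)  # {'shards': [2, 4], 'placements': ('Shard(0)',)}
```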

rule_shard_dim1 = (
[Replicate(), Shard(1), Replicate(), Shard(0)],
[Shard(1)] + non_tensor_args,
)
Contributor

why did these dim1 quantization kernel sharding rules need to be updated?

also, i thought the rule tuple order was (inputs, outputs) but it seems like this is the opposite, do i have it backwards?


@pianpwk pianpwk requested a review from danielvegamyhre March 19, 2026 19:14
@danielvegamyhre
Contributor

@pianpwk when i use latest torch nightly, checkout this PR and run ./test/prototype/mx_formats/test_mx_dtensor.sh the test fails with:

[rank0]:   File "/home/dev/ao/torchao/prototype/mx_formats/mx_tensor.py", line 409, in to_dtype
[rank0]:     data_hp = data_hp * s_fp
[rank0]:               ~~~~~~~~^~~~~~
[rank0]: RuntimeError: The size of tensor a (8) must match the size of tensor b (4) at non-singleton dimension 0

@pianpwk
Author

pianpwk commented Mar 19, 2026

@pianpwk when i use latest torch nightly, checkout this PR and run ./test/prototype/mx_formats/test_mx_dtensor.sh the test fails with:

[rank0]:   File "/home/dev/ao/torchao/prototype/mx_formats/mx_tensor.py", line 409, in to_dtype
[rank0]:     data_hp = data_hp * s_fp
[rank0]:               ~~~~~~~~^~~~~~
[rank0]: RuntimeError: The size of tensor a (8) must match the size of tensor b (4) at non-singleton dimension 0

hmm not able to repro this
