[TP] reorder MXFP8 wrapper over DTensor#4010
Conversation
✅ No failures as of commit f31cd57 with merge base 1d75a07.
thanks for investigating this further @pianpwk ! i added a couple notes to assist

@danielvegamyhre updated the PR to cover more spots where we should reorder. I didn't review carefully but at a high-level it should make more sense? Also had to override a DTensor sharding rule for scaled_mm - I'll upstream that to pytorch instead, if it turns out to be valid.

nice @pianpwk, qq, why did you need to register a custom sharding strategy for scaled_mm, does the one in core have a bug?
```python
run_check=False,
shape=data_lp.size(),
stride=data_lp.stride(),
)
```
consequence of reversing order?
makes sense, this would rewrap as "DTensor(MXTensor(...))", which is the opposite order of what we are doing now. nice that all this can be removed now, cleaner
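A minimal pure-Python sketch of why the nesting order matters (toy names, not the real torch subclasses): the outermost wrapper is the one whose interception hook runs first, mirroring how the outermost tensor subclass wins `__torch_function__` dispatch.

```python
# Toy wrappers standing in for DTensor and MXTensor (hypothetical names).
# Whichever wrapper is outermost handles the call first and then delegates
# inward, just like nested tensor-subclass dispatch.
class ToyWrapper:
    def __init__(self, name, inner):
        self.name = name
        self.inner = inner

    def handle(self, op):
        # Record who intercepted first, then delegate to the inner wrapper.
        trace = [self.name]
        if isinstance(self.inner, ToyWrapper):
            trace += self.inner.handle(op)
        return trace

weight = "plain_weight"
old_order = ToyWrapper("DTensor", ToyWrapper("MXTensor", weight))
new_order = ToyWrapper("MXTensor", ToyWrapper("DTensor", weight))

print(old_order.handle("linear"))  # DTensor intercepts first
print(new_order.handle("linear"))  # MXTensor intercepts first
```

With the old DTensor(MXTensor) nesting, DTensor gets first crack at the op; reversing the order lets the MX wrapper see it first, which is the whole point of this PR.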
For pytorch/ao#4010, the existing strategy incorrectly copies scaled_mm input strategies (2d) onto the scale tensor (1d). Pull Request resolved: #177234 Approved by: https://github.com/danielvegamyhre
also for pytorch/ao#4010, adds custom handler to check local tensor Pull Request resolved: #177235 Approved by: https://github.com/Skylion007, https://github.com/wconstab ghstack dependencies: #177234
danielvegamyhre left a comment:
nice, looks much cleaner now with that dtensor 1d scale sharding fix landed in core! couple minor comments
```python
del t
tests.append(_test_mxfp8_mlp_tensor_parallelism_auto)
except Exception:
    print("Skipping auto test: mxfp8_quantize CUDA kernel not available")
```
if `_mxfp8_cuda_kernels_available`, let's just append the test case without doing the kernel dispatch test
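A sketch of the suggested guard (with `_mxfp8_cuda_kernels_available` and the test function stubbed out here, since the real ones live in the test file):

```python
def _mxfp8_cuda_kernels_available():
    # Stub for illustration; the real check would probe for the CUDA kernel.
    return False

def _test_mxfp8_mlp_tensor_parallelism_auto():
    pass  # placeholder for the real test case

tests = []
if _mxfp8_cuda_kernels_available():
    tests.append(_test_mxfp8_mlp_tensor_parallelism_auto)
else:
    print("Skipping auto test: mxfp8_quantize CUDA kernel not available")
```

This gates on a simple availability check instead of a try/except around an actual kernel dispatch, which is what the reviewer is suggesting.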
```python
# For MXFP8: parallelize first, then quantize.
# This puts MXFP8 wrapper on top of DTensor so __torch_function__
# intercepts F.linear before DTensor can trigger premature all-gathers.
```
@vkuzo this wrapping order change to MXFP8WeightWrapperTensor(DTensor(...)) is a pretty fundamental one, if you'd like to review as well. see the comment above, it is working cleanly now after @pianpwk landed a fix to the DTensor sharding rules for 1d/flattened scale factors for torch._scaled_mm: pytorch/pytorch#177234
Problem: when DTensor wraps MXFP8 (quantize first, then parallelize), DTensor dispatches first on F.linear and performs a premature Shard→Replicate all-gather before MXFP8's __torch_function__ can intercept, causing both ranks to see identical full weights and producing wrong numerics.

Fix: reverse the wrapping order for MXFP8 so MXFP8 sits on top of DTensor (parallelize first, then quantize). MXFP8's __torch_function__ intercepts F.linear first, unwraps to get the DTensor via a differentiable unwrap_weight() helper, and DTensor handles sharding at the aten op level.

Changes:
- tensor.py: add scatter_ to preserved ops (TP weight distribution), fix pin_memory for DTensor, narrow the linear override to func name "linear" only, use unwrap_weight(B) in both grouped_mm and linear paths, add a _UnwrapWeight autograd function
- utils.py: transpose DTensor placements in _to_mxfp8_dim1_kernel_wrapper to match the transposed local data
- dtensor_utils.py: make the parallelize/quantize order conditional on config type (MXFP8: parallelize first; Float8: quantize first), use SQNR-based assertions for MXFP8, bf16 model/inputs
- test_mx_dtensor.py: update to MXFP8 (was FP4), split into emulated and auto tests with a CUDA kernel availability guard
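A minimal sketch of the conditional ordering described in the change summary (all names hypothetical, toy wrappers rather than real tensor subclasses): MXFP8 parallelizes first so its wrapper ends up outermost, while Float8 keeps the quantize-first order.

```python
class Wrapped:
    """Toy stand-in for one tensor-subclass wrapper layer."""
    def __init__(self, kind, inner):
        self.kind = kind
        self.inner = inner

def parallelize(model):
    return Wrapped("DTensor", model)

def quantize(model, recipe):
    return Wrapped(recipe, model)

def prepare(model, recipe):
    if recipe == "MXFP8":
        # Parallelize first, then quantize: MXFP8(DTensor(model)).
        return quantize(parallelize(model), recipe)
    # Float8 keeps the old order: DTensor(Float8(model)).
    return parallelize(quantize(model, recipe))

mx = prepare("model", "MXFP8")
fp8 = prepare("model", "Float8")
```

The conditional puts the wrapper that needs first crack at F.linear on the outside for MXFP8 only, leaving the Float8 path unchanged.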
```python
to_blocked,
in_placements=(t_placements,),
out_placements=t_placements,
)(t)
```
to confirm my understanding, local_map just runs the function (to_blocked) on each local shard as if it were a plain tensor, and then rewraps the output in a DTensor according to out_placements, right?
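A toy model of that understanding (this is not the real torch.distributed.tensor.experimental.local_map, just its conceptual shape): unwrap to the local shard, run the plain-tensor function, rewrap per out_placements.

```python
class FakeDTensor:
    """Toy DTensor: a local shard plus placement metadata."""
    def __init__(self, local, placements):
        self.local = local
        self.placements = placements

def local_map_like(fn, out_placements):
    # fn sees only the plain local shard; the result is rewrapped
    # with the caller-declared output placements.
    def wrapped(dt):
        result = fn(dt.local)
        return FakeDTensor(result, out_placements)
    return wrapped

t = FakeDTensor([1, 2, 3], placements=("Shard(0)",))
doubled = local_map_like(lambda x: [v * 2 for v in x],
                         out_placements=("Shard(0)",))(t)
```

Note that out_placements is a promise from the caller, not something the wrapper verifies: the function must actually produce data consistent with those placements.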
```python
rule_shard_dim1 = (
    [Replicate(), Shard(1), Replicate(), Shard(0)],
    [Shard(1)] + non_tensor_args,
)
```
why did these dim1 quantization kernel sharding rules need to be updated?
also, i thought the rule tuple order was (inputs, outputs) but it seems like this is the opposite, do i have it backwards?
Ahh ordering is [*outputs, *inputs], I think that's the issue: https://github.com/pytorch/pytorch/blob/6511db6ea2ae6133a5260076862f63540564897a/torch/distributed/tensor/experimental/_register_sharding.py#L22
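A small sketch of that [*outputs, *inputs] convention (placement names as plain strings here, purely illustrative; the real specs are DTensor placement objects):

```python
# Each strategy lists output specs before input specs, matching the
# [*outputs, *inputs] ordering from the linked _register_sharding.py.
def make_rule(output_specs, input_specs):
    return (*output_specs, *input_specs)

rule = make_rule(
    output_specs=["Shard(1)"],
    input_specs=["Replicate()", "Shard(0)"],
)
```

Reading a rule with the (inputs, outputs) assumption silently swaps every placement, which is exactly the confusion in the question above.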
@pianpwk when i use latest torch nightly, checkout this PR and run
hmm not able to repro this |
For pytorch/pytorch#177059, #3985
The original attempt to handle TP + MXFP8 wrapped DTensor over the MXFP8 subclass. The MXFP8 subclass intends to capture ops at the __torch_function__ level, and use custom autograd functions to control fwd/bwd behavior. Because DTensor has no fwd/bwd coupling, and because it CIA-decomposes aten::linear, this ordering does not work; the MXFP8 tensor never sees aten::linear at __torch_function__.
This PR reverses the order to MXFP8(DTensor), allowing aten::linear interception and fwd/bwd control. Relies on pytorch/pytorch#177234 landing in pytorch.
I understand other dtypes still need reordering?
More details in discussion of pytorch/pytorch#177059