-
Notifications
You must be signed in to change notification settings - Fork 463
[TP] reorder MXFP8 wrapper over DTensor #4010
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6dd3783
162ee76
c44d9d3
d67ca9c
762ba23
1dec8aa
a2c4af2
f31cd57
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,8 @@ | |
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from torch.distributed._tensor import DTensor | ||
| from torch.distributed.tensor import DTensor, Replicate, Shard | ||
| from torch.distributed.tensor.experimental import local_map | ||
| from torch.utils._python_dispatch import ( | ||
| return_and_correct_aliasing, | ||
| ) | ||
|
|
@@ -353,7 +354,7 @@ def get_fp_scale(scale_e8m0): | |
| s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS | ||
| # TODO(later): it would be nice if there was a way to do the 2^x operation | ||
| # in PyTorch without creating a tensor of twos | ||
| two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) | ||
| two = torch.full_like(s_offset, 2.0, dtype=torch.float32) | ||
| # pow(two, s_offset) can be out of range of floating point formats. | ||
| # TODO(later): handle this for float16 if we decide to support float16 | ||
| # scales. | ||
|
|
@@ -560,35 +561,6 @@ def from_qdata_and_scales( | |
| ) | ||
| elem_dtype = qdata.dtype | ||
|
|
||
| if isinstance(qdata, DTensor) or isinstance(scales, DTensor): | ||
| assert isinstance(qdata, DTensor) and isinstance(scales, DTensor), ( | ||
| "qdata and scales must either both be DTensors or both be local tensors" | ||
| ) | ||
| assert qdata.device_mesh == scales.device_mesh, ( | ||
| "qdata and scales DTensors must have the same device mesh" | ||
| ) | ||
| assert qdata.placements == scales.placements, ( | ||
| "qdata and scales DTensors must have the same placements" | ||
| ) | ||
| inner_mx_tensor = MXTensor( | ||
| qdata.to_local(), | ||
| scales.to_local(), | ||
| elem_dtype, | ||
| block_size, | ||
| orig_dtype, | ||
| kernel_preference, | ||
| act_quant_kwargs, | ||
| is_swizzled_scales, | ||
| ) | ||
| return DTensor.from_local( | ||
| inner_mx_tensor, | ||
| qdata.device_mesh, | ||
| qdata.placements, | ||
| run_check=False, | ||
| shape=qdata.size(), | ||
| stride=qdata.stride(), | ||
| ) | ||
|
|
||
| return MXTensor( | ||
| qdata, | ||
| scales, | ||
|
|
@@ -640,28 +612,6 @@ def to_mx( | |
| inner_block_size=block_size, | ||
| scaling_mode=scaling_mode.value, | ||
| ) | ||
| if isinstance(scale_e8m0_biased, DTensor): | ||
| assert isinstance(data_lp, DTensor), "unsupported" | ||
| local_scale_e8m0_biased = scale_e8m0_biased.to_local() | ||
| local_data_lp = data_lp.to_local() | ||
| inner_mx_tensor = MXTensor( | ||
| local_data_lp, | ||
| local_scale_e8m0_biased, | ||
| elem_dtype, | ||
| block_size, | ||
| data_hp.dtype, | ||
| kernel_preference, | ||
| act_quant_kwargs, | ||
| is_swizzled_scales, | ||
| ) | ||
| return DTensor.from_local( | ||
| inner_mx_tensor, | ||
| data_lp.device_mesh, | ||
| data_lp.placements, | ||
| run_check=False, | ||
| shape=data_lp.size(), | ||
| stride=data_lp.stride(), | ||
| ) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consequence of reversing order?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. makes sense, this would rewrap as "Dtensor(MXTensor(...))" which the opposite order of what we are doing now. nice that all this can be removed now, cleaner |
||
| return MXTensor( | ||
| data_lp, | ||
| scale_e8m0_biased, | ||
|
|
@@ -703,6 +653,29 @@ def _get_gemm_choice( | |
| return choice_a if choice_a is not None else choice_b | ||
|
|
||
|
|
||
| def maybe_dtensor_to_blocked(t: torch.Tensor) -> torch.Tensor: | ||
| # redistribute to Replicate or Shard(0); to_blocked will view/permute/flatten into a 1d tensor | ||
| # sharding is only preservable on the first dimension. | ||
| if isinstance(t, DTensor): | ||
| t_placements = [ | ||
| x if x in (Replicate(), Shard(0)) else Replicate() for x in t.placements | ||
| ] | ||
| if t_placements != t.placements: # can't perform collectives in float8 | ||
| t = ( | ||
| t.view(torch.uint8) | ||
| .redistribute(placements=t_placements) | ||
| .view(torch.float8_e8m0fnu) | ||
| ) | ||
| out = local_map( | ||
| to_blocked, | ||
| in_placements=(t_placements,), | ||
| out_placements=t_placements, | ||
| )(t) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to confirm my understanding, local_map just runs the function (to_blocked) on each local shard as if it were a plain tensor, and then rewraps the output in a dtensor according to out_placements right
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep! |
||
| else: | ||
| out = to_blocked(t) | ||
| return out | ||
|
|
||
|
|
||
| def _addmm_mx_dispatch( | ||
| a: torch.Tensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None | ||
| ) -> torch.Tensor: | ||
|
|
@@ -737,13 +710,13 @@ def _addmm_mx_dispatch( | |
| a_scale_block = a.scale | ||
| else: | ||
| a_scale = a.scale.view(M, K // a.block_size) | ||
| a_scale_block = to_blocked(a_scale) | ||
| a_scale_block = maybe_dtensor_to_blocked(a_scale) | ||
|
|
||
| if b.is_swizzled_scales: | ||
| b_scale_block = b.scale.t() | ||
| else: | ||
| b_scale = b.scale.t().view(N, K // b.block_size) | ||
| b_scale_block = to_blocked(b_scale) | ||
| b_scale_block = maybe_dtensor_to_blocked(b_scale) | ||
|
|
||
| if a.elem_dtype == torch.float8_e4m3fn: | ||
| assert b.elem_dtype == torch.float8_e4m3fn | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did these dim1 quantization kernel sharding rules need to be updated?
also, i thought the rule tuple order was (inputs, outputs) but it seems like this is the opposite, do i have it backwards?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh ordering is [*outputs, *inputs], I think that's the issue: https://github.com/pytorch/pytorch/blob/6511db6ea2ae6133a5260076862f63540564897a/torch/distributed/tensor/experimental/_register_sharding.py#L22