Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions test/prototype/mx_formats/test_mx_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from tqdm import tqdm

from torchao.prototype.moe_training.config import MXFP8TrainingOpConfig
from torchao.prototype.moe_training.config import (
MXFP8TrainingOpConfig,
MXFP8TrainingRecipe,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.testing.training.dtensor_utils import (
_test_lowp_mlp_tensor_parallelism_base,
Expand All @@ -43,52 +46,80 @@ def setup_distributed():
return device_mesh


def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=1024):
    """Cast a full tensor and its row-sharded DTensor counterpart to MXFP8 and
    verify that the sharded result matches the corresponding slice of the
    unsharded result.
    """
    device = mesh.device_type

    x_fp32 = torch.rand(size, size, device=device)
    x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=32)

    dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
    dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=32)

    # With the new wrapping order, MXTensor is the outer wrapper with DTensor
    # inner tensors (MXTensor(DTensor_qdata, DTensor_scale)).
    assert isinstance(dist_x_fp8, MXTensor), (
        f"Expected MXTensor, got {type(dist_x_fp8)}"
    )
    assert isinstance(dist_x_fp8.qdata, DTensor), (
        f"Expected qdata to be DTensor, got {type(dist_x_fp8.qdata)}"
    )
    assert isinstance(dist_x_fp8.scale, DTensor), (
        f"Expected scale to be DTensor, got {type(dist_x_fp8.scale)}"
    )

    # Verify that the result of to_mx with DTensor matches the slice of the
    # result of to_mx without DTensor. This will fail on numeric op mismatches.
    local_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    assert size % world_size == 0, "unsupported"
    x_fp8_fp32 = x_fp8.dequantize(torch.bfloat16)
    rows_per_slice = size // world_size
    slice_start = local_rank * rows_per_slice
    slice_end = (local_rank + 1) * rows_per_slice
    x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
    # dequantize handles DTensor inner tensors and returns a DTensor
    dist_x_fp8_dequant = dist_x_fp8.dequantize(torch.bfloat16)
    assert isinstance(dist_x_fp8_dequant, DTensor), (
        f"Expected dequantize result to be DTensor, got {type(dist_x_fp8_dequant)}"
    )
    torch.testing.assert_close(
        x_fp8_fp32_slice,
        dist_x_fp8_dequant.to_local(),
        atol=0,
        rtol=0,
    )


def _test_mxfp8_mlp_tensor_parallelism_emulated(mesh: DeviceMesh, size=64):
    """Run the tensor-parallel MLP test with the emulated MXFP8 rceil recipe
    (eager mode, high-precision all-gather)."""
    recipe = MXFP8TrainingRecipe("mxfp8_emulated_rceil")
    config = MXFP8TrainingOpConfig.from_recipe(recipe)
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size, compile=False, allgather_in_lowp=False
    )


def _test_mxfp8_mlp_tensor_parallelism_auto(mesh: DeviceMesh, size=64):
    """Run the tensor-parallel MLP test with the (non-emulated) MXFP8 rceil
    recipe; requires the CUDA kernels gated in ``__main__``."""
    recipe = MXFP8TrainingRecipe("mxfp8_rceil")
    config = MXFP8TrainingOpConfig.from_recipe(recipe)
    _test_lowp_mlp_tensor_parallelism_base(
        mesh, config, size, compile=False, allgather_in_lowp=False
    )


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp4,
_test_mxfp8_mlp_tensor_parallelism,
_test_dtensor_cast_to_mxfp8,
_test_mxfp8_mlp_tensor_parallelism_emulated,
]
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
_mxfp8_cuda_kernels_available,
)

if _mxfp8_cuda_kernels_available:
tests.append(_test_mxfp8_mlp_tensor_parallelism_auto)
else:
print("Skipping auto test: requires SM >= 100 and CUDA >= 12.8")

for test in tqdm(tests, desc="Running tests"):
try:
Expand Down
8 changes: 5 additions & 3 deletions torchao/prototype/moe_training/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
torch.ops.aten.clone.default,
torch.ops.aten.transpose.int,
torch.ops.aten.t.default,
# required for TP - scatter_ is used to distribute weights
torch.ops.c10d.scatter_.default,
}


Expand Down Expand Up @@ -288,16 +290,16 @@ def __torch_function__(cls, func, types, args, kwargs={}):
if A_is_2d and B_is_2d_or_3d and offs is not None:
return _quantize_then_scaled_grouped_mm(
A,
B,
unwrap_weight(B),
offs=offs,
config=config,
)

# linear op override
elif func.__name__ in ("linear", "mm", "matmul", "addmm"):
A, B = args[0], args[1]
assert not isinstance(A, cls), f"A should not be a {cls.__name__}"

assert not isinstance(A, cls), f"A should not be a {cls.__name__}"
assert isinstance(B, cls), f"B should be a {cls.__name__}"

config = B.config
Expand All @@ -307,7 +309,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):

return _to_mxfp8_then_scaled_mm(
A,
B,
unwrap_weight(B),
kernel_preference=config.kernel_preference,
scale_calculation_mode=config.scale_calculation_mode,
wgrad_with_hp=config.wgrad_with_hp,
Expand Down
60 changes: 33 additions & 27 deletions torchao/prototype/mx_formats/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,35 +1122,41 @@ def custom_mxfp8_quantize_cuda_dim1_sharding(
fp8_format: str,
scaling_mode: str,
):
# This function signature can be used to understand the shardings:
# _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise=False, colwise=True)

# When inputs and scale are replicated, we return a quantized output tensor (replicated).
inputs_replicated = [None, Replicate(), None, Replicate()]
outputs_replicated = [None, Replicate(), None, None]
rule_for_input_replicated = (
inputs_replicated,
outputs_replicated,
# Op returns 4 tensors: (output_rowwise, output_colwise, scales_rowwise, scales_colwise)
# When rowwise=False, outputs 0 and 2 are empty tensors (size 0).
# output_colwise has shape (rows, cols) in col-major order.
# scales_colwise has shape (cols, num_row_blocks) in col-major order.
#
# Format: (output_placements, input_placements)
# Input placements: one per arg (x=Tensor, then 6 non-tensor args=None)
# Output placements: one per output tensor (4 total)

non_tensor_args = [None, None, None, None, None, None]

# When input is replicated, all outputs are replicated.
rule_replicated = (
[Replicate(), Replicate(), Replicate(), Replicate()],
[Replicate()] + non_tensor_args,
)

# When inputs and scale are sharded along dim 0,
# we return a quantized output tensor (sharded along dim1 due to transpose).
inputs_sharded_dim0 = [None, Shard(0), None, Shard(0)]
outputs_sharded_dim1 = [None, Shard(1), None, None]
rule_for_input_sharded_dim0 = (inputs_sharded_dim0, outputs_sharded_dim1)

# When inputs and scale are sharded along dim 1,
# we return a quantized output tensor (sharded along dim0 due to transpose).
inputs_sharded_dim1 = [None, Shard(1), None, Shard(1)]
outputs_sharded_dim0 = [None, Shard(0), None, None]
rule_for_input_sharded_dim1 = (inputs_sharded_dim1, outputs_sharded_dim0)

acceptable_shardings = [
rule_for_input_replicated,
rule_for_input_sharded_dim0,
rule_for_input_sharded_dim1,
]
return acceptable_shardings
# When input is sharded along dim 0:
# output_colwise (rows, cols) col-major: rows are sharded → Shard(0)
# scales_colwise (cols, num_row_blocks) col-major: row blocks sharded → Shard(1)
# Unused rowwise outputs (empty tensors): Replicate()
rule_shard_dim0 = (
[Replicate(), Shard(0), Replicate(), Shard(1)],
[Shard(0)] + non_tensor_args,
)

# When input is sharded along dim 1:
# output_colwise: cols are sharded → Shard(1)
# scales_colwise: col dim is sharded → Shard(0)
rule_shard_dim1 = (
[Replicate(), Shard(1), Replicate(), Shard(0)],
[Shard(1)] + non_tensor_args,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did these dim1 quantization kernel sharding rules need to be updated?

also, i thought the rule tuple order was (inputs, outputs) but it seems like this is the opposite, do i have it backwards?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


return [rule_replicated, rule_shard_dim0, rule_shard_dim1]

else:

Expand Down
83 changes: 28 additions & 55 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.experimental import local_map
from torch.utils._python_dispatch import (
return_and_correct_aliasing,
)
Expand Down Expand Up @@ -353,7 +354,7 @@ def get_fp_scale(scale_e8m0):
s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
# TODO(later): it would be nice if there was a way to do the 2^x operation
# in PyTorch without creating a tensor of twos
two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
two = torch.full_like(s_offset, 2.0, dtype=torch.float32)
# pow(two, s_offset) can be out of range of floating point formats.
# TODO(later): handle this for float16 if we decide to support float16
# scales.
Expand Down Expand Up @@ -560,35 +561,6 @@ def from_qdata_and_scales(
)
elem_dtype = qdata.dtype

if isinstance(qdata, DTensor) or isinstance(scales, DTensor):
assert isinstance(qdata, DTensor) and isinstance(scales, DTensor), (
"qdata and scales must either both be DTensors or both be local tensors"
)
assert qdata.device_mesh == scales.device_mesh, (
"qdata and scales DTensors must have the same device mesh"
)
assert qdata.placements == scales.placements, (
"qdata and scales DTensors must have the same placements"
)
inner_mx_tensor = MXTensor(
qdata.to_local(),
scales.to_local(),
elem_dtype,
block_size,
orig_dtype,
kernel_preference,
act_quant_kwargs,
is_swizzled_scales,
)
return DTensor.from_local(
inner_mx_tensor,
qdata.device_mesh,
qdata.placements,
run_check=False,
shape=qdata.size(),
stride=qdata.stride(),
)

return MXTensor(
qdata,
scales,
Expand Down Expand Up @@ -640,28 +612,6 @@ def to_mx(
inner_block_size=block_size,
scaling_mode=scaling_mode.value,
)
if isinstance(scale_e8m0_biased, DTensor):
assert isinstance(data_lp, DTensor), "unsupported"
local_scale_e8m0_biased = scale_e8m0_biased.to_local()
local_data_lp = data_lp.to_local()
inner_mx_tensor = MXTensor(
local_data_lp,
local_scale_e8m0_biased,
elem_dtype,
block_size,
data_hp.dtype,
kernel_preference,
act_quant_kwargs,
is_swizzled_scales,
)
return DTensor.from_local(
inner_mx_tensor,
data_lp.device_mesh,
data_lp.placements,
run_check=False,
shape=data_lp.size(),
stride=data_lp.stride(),
)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consequence of reversing order?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, this would rewrap as "Dtensor(MXTensor(...))" which the opposite order of what we are doing now. nice that all this can be removed now, cleaner

return MXTensor(
data_lp,
scale_e8m0_biased,
Expand Down Expand Up @@ -703,6 +653,29 @@ def _get_gemm_choice(
return choice_a if choice_a is not None else choice_b


def maybe_dtensor_to_blocked(t: torch.Tensor) -> torch.Tensor:
    """Apply ``to_blocked`` to a scale tensor, transparently handling DTensor.

    ``to_blocked`` views/permutes/flattens into a 1d tensor, so sharding is
    only preservable on the first dimension; any other placement is first
    redistributed to Replicate.
    """
    if not isinstance(t, DTensor):
        return to_blocked(t)

    # Keep Replicate and Shard(0) placements; everything else becomes Replicate.
    placements = [
        p if p in (Replicate(), Shard(0)) else Replicate() for p in t.placements
    ]
    if placements != t.placements:
        # Collectives can't run on float8 dtypes, so round-trip the
        # redistribute through a uint8 view.
        t = (
            t.view(torch.uint8)
            .redistribute(placements=placements)
            .view(torch.float8_e8m0fnu)
        )
    # local_map runs to_blocked on each local shard as a plain tensor and
    # rewraps the output as a DTensor with the given placements.
    return local_map(
        to_blocked,
        in_placements=(placements,),
        out_placements=placements,
    )(t)


def _addmm_mx_dispatch(
a: torch.Tensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
Expand Down Expand Up @@ -737,13 +710,13 @@ def _addmm_mx_dispatch(
a_scale_block = a.scale
else:
a_scale = a.scale.view(M, K // a.block_size)
a_scale_block = to_blocked(a_scale)
a_scale_block = maybe_dtensor_to_blocked(a_scale)

if b.is_swizzled_scales:
b_scale_block = b.scale.t()
else:
b_scale = b.scale.t().view(N, K // b.block_size)
b_scale_block = to_blocked(b_scale)
b_scale_block = maybe_dtensor_to_blocked(b_scale)

if a.elem_dtype == torch.float8_e4m3fn:
assert b.elem_dtype == torch.float8_e4m3fn
Expand Down
Loading
Loading