Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions test/prototype/mx_formats/test_mx_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import torch

from torchao.utils import torch_version_at_least
from torchao.utils import is_sm_at_least_100, torch_version_at_least

if not torch_version_at_least("2.7.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
Expand All @@ -24,7 +24,10 @@
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from tqdm import tqdm

from torchao.prototype.moe_training.config import MXFP8TrainingOpConfig
from torchao.prototype.moe_training.config import (
MXFP8TrainingOpConfig,
MXFP8TrainingRecipe,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.testing.training.dtensor_utils import (
_test_lowp_mlp_tensor_parallelism_base,
Expand All @@ -43,52 +46,64 @@ def setup_distributed():
return device_mesh


def _test_dtensor_cast_to_mxfp4(mesh: DeviceMesh, size=4):
def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=1024):
device = mesh.device_type

x_fp32 = torch.rand(size, size, device=device)
x_fp4 = MXTensor.to_mx(x_fp32, torch.float4_e2m1fn_x2, block_size=size // 2)
x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=32)

dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
dist_x_fp4 = MXTensor.to_mx(
dist_x_fp32, torch.float4_e2m1fn_x2, block_size=size // 2
)
assert isinstance(dist_x_fp4, DTensor)
dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=32)
assert isinstance(dist_x_fp8, DTensor)

# Verify that the result of to_mx with DTensor matches the slice of the
# result of to_mx without DTensor. This will fail on numeric op mismatches.
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert size % world_size == 0, "unsupported"
x_fp4_fp32 = x_fp4.dequantize(torch.float32)
x_fp8_fp32 = x_fp8.dequantize(torch.bfloat16)
rows_per_slice = size // world_size
slice_start = local_rank * rows_per_slice
slice_end = (local_rank + 1) * rows_per_slice
x_fp4_fp32_slice = x_fp4_fp32[slice_start:slice_end]
x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
torch.testing.assert_close(
x_fp4_fp32_slice,
dist_x_fp4.to_local().dequantize(torch.float32),
x_fp8_fp32_slice,
dist_x_fp8.to_local().dequantize(torch.bfloat16),
atol=0,
rtol=0,
)


def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
config = MXFP8TrainingOpConfig()
def _test_mxfp8_mlp_tensor_parallelism_emulated(mesh: DeviceMesh, size=64):
recipe = MXFP8TrainingRecipe("mxfp8_emulated_rceil")
config = MXFP8TrainingOpConfig.from_recipe(recipe)
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)
# _test_lowp_mlp_tensor_parallelism_base(
# mesh, config, size, compile=True, allgather_in_lowp=False
# )


def _test_mxfp8_mlp_tensor_parallelism_auto(mesh: DeviceMesh, size=64):
recipe = MXFP8TrainingRecipe("mxfp8_rceil")
config = MXFP8TrainingOpConfig.from_recipe(recipe)
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=True, allgather_in_lowp=False
mesh, config, size, compile=False, allgather_in_lowp=False
)
# _test_lowp_mlp_tensor_parallelism_base(
# mesh, config, size, compile=True, allgather_in_lowp=False
# )


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp4,
_test_mxfp8_mlp_tensor_parallelism,
# _test_dtensor_cast_to_mxfp8,
_test_mxfp8_mlp_tensor_parallelism_emulated,
]
if is_sm_at_least_100():
tests.append(_test_mxfp8_mlp_tensor_parallelism_auto)

for test in tqdm(tests, desc="Running tests"):
try:
Expand Down
57 changes: 44 additions & 13 deletions torchao/prototype/moe_training/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@
torch.ops.aten.clone.default,
torch.ops.aten.transpose.int,
torch.ops.aten.t.default,
# required for TP - scatter_ is used to distribute weights
torch.ops.c10d.scatter_.default,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is the real fix, right? Do we still need the other ops you commented out?

# # required for quantizing weights
# torch.ops.aten.mul.Tensor,
# torch.ops.aten.abs.default,
# torch.ops.aten.amax.default,
# torch.ops.aten.clamp.default,
# torch.ops.aten.to.dtype,
# torch.ops.aten.unsqueeze.default,
# torch.ops.aten.div.Tensor,
# torch.ops.aten.reshape.default,
# torch.ops.aten.isnan.default,
# torch.ops.aten.log2.default,
# torch.ops.aten.where.default,
# torch.ops.aten.where.self,
# torch.ops.aten.ceil.default,
# torch.ops.aten.view.dtype,
# torch.ops.aten.squeeze.dim,
}


Expand Down Expand Up @@ -89,6 +107,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs={}):
print("[TORCH_DISPATCH]: ", func.__name__)
# unwrap args/kwargs and extract config
config = None

Expand Down Expand Up @@ -222,10 +241,6 @@ def __torch_function__(cls, func, types, args, kwargs={}):
# Use torchao scaled grouped mm with dynamic quant for
# "2d x 3d with offsets" case (used for routed experts).
# Otherwise, fall back to regular grouped mm.
#
# TODO: support "3d x 3d without offsets" case, which is
# used for shared experts. This is basically the grouped_mm
# kernel handling a bmm.
A, B = args[0], args[1]

assert not isinstance(A, cls), f"A should not be a {cls.__name__}"
Expand Down Expand Up @@ -263,14 +278,11 @@ class MXFP8TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
# grouped_mm op override
print("[TORCH_FUNCTION]", func.__name__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove these before landing?

if func.__name__ == "_grouped_mm":
# Use torchao scaled grouped mm with dynamic quant for
# "2d x 3d with offsets" case (used for routed experts).
# Otherwise, fall back to regular grouped mm.
#
# TODO: support "3d x 3d without offsets" case, which is
# used for shared experts. This is basically the grouped_mm
# kernel handling a bmm.
A, B = args[0], args[1]

assert not isinstance(A, cls), f"A should not be a {cls.__name__}"
Expand All @@ -285,26 +297,29 @@ def __torch_function__(cls, func, types, args, kwargs={}):
if A_is_2d and B_is_2d_or_3d and offs is not None:
return _quantize_then_scaled_grouped_mm(
A,
B,
unwrap_weight(B),
offs=offs,
config=config,
)

# linear op override
elif func.__name__ in ("linear", "mm", "matmul", "addmm"):
elif func.__name__ in ("linear", "mm", "mm.default"):
A, B = args[0], args[1]
assert not isinstance(A, cls), f"A should not be a {cls.__name__}"

assert not isinstance(A, cls), f"A should not be a {cls.__name__}"
assert isinstance(B, cls), f"B should be a {cls.__name__}"

config = B.config
assert isinstance(config, MXFP8TrainingOpConfig), (
"expected MXFP8TrainingOpConfig"
)


# Log weight shard statistics
weight = B._data

return _to_mxfp8_then_scaled_mm(
A,
B,
unwrap_weight(B),
kernel_preference=config.kernel_preference,
scale_calculation_mode=config.scale_calculation_mode,
wgrad_with_hp=config.wgrad_with_hp,
Expand All @@ -315,3 +330,19 @@ def __torch_function__(cls, func, types, args, kwargs={}):
# the wrapping behavior of the super() impl, go directly to dispatch
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)


class _UnwrapWeight(torch.autograd.Function):
"""Helper to unwrap the tensor subclass in a differentiable way"""

@staticmethod
def forward(ctx, wrapper_tensor):
return wrapper_tensor._data

@staticmethod
def backward(ctx, grad_output):
return grad_output


def unwrap_weight(wrapper_tensor):
return _UnwrapWeight.apply(wrapper_tensor)
2 changes: 1 addition & 1 deletion torchao/prototype/mx_formats/mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def forward(
scale_calculation_mode: ScaleCalculationMode,
wgrad_with_hp: bool,
):
print("mx_mm forward")
ctx.save_for_backward(input_hp, weight_hp)
ctx.in_elem_dtype = in_elem_dtype
ctx.w_elem_dtype = w_elem_dtype
Expand Down Expand Up @@ -137,7 +138,6 @@ def forward(
)
output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
output = output.reshape(*input_orig_shape[:-1], output.shape[-1])

return output

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def to_mx(
assert data_hp.shape[-1] % block_size == 0, (
f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
)
assert data_hp.is_contiguous(), "unsupported"
if not data_hp.is_contiguous():
assert data_hp.is_contiguous(), "unsupported"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by this — if it's not contiguous, it would fail just like before, so is there a reason behind this change?

Copy link
Contributor Author

@danielvegamyhre danielvegamyhre Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andrewor14 Sorry, I linked the wrong PR — this is not the one that will address the test failures. This is a WIP draft for an issue we are still trying to find a proper solution for, so please disregard this PR.

assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

orig_shape = data_hp.shape
Expand Down
89 changes: 66 additions & 23 deletions torchao/testing/training/dtensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from torchao.prototype.moe_training.config import MXFP8TrainingOpConfig
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error


class FeedForward(nn.Module):
Expand Down Expand Up @@ -67,7 +68,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
if isinstance(config, MXFP8TrainingOpConfig):
convert_model_func = quantize_

toy_model = ToyModel(size).to(device)
toy_model = ToyModel(size).to(device).to(torch.bfloat16)
toy_model_fp8 = copy.deepcopy(toy_model)
convert_model_func(toy_model_fp8, config=config)

Expand Down Expand Up @@ -151,32 +152,74 @@ def _test_lowp_mlp_tensor_parallelism_base(
sp_model = torch.compile(sp_model)
sp_model2 = torch.compile(sp_model2)

x_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False)
go_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()
go_fp32_tp = go_fp32.clone()
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)])

tp_out = tp_model(x_fp32_tp_input)
tp_out.backward(go_fp32_tp)
sp_out = sp_model(x_fp32_sp_input)
sp_out.backward(go_fp32_sp)
global_out = toy_model_fp8(x_fp32)
global_out.backward(go_fp32)
torch.testing.assert_close(tp_out, global_out)
torch.testing.assert_close(sp_out.full_tensor(), global_out)
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
torch.testing.assert_close(
x_bf16 = torch.rand(
2, size * 2, size, device=device, requires_grad=False, dtype=torch.bfloat16
)
go_bf16 = torch.rand(
2, size * 2, size, device=device, requires_grad=False, dtype=torch.bfloat16
)
x_bf16_tp_input = x_bf16.clone()
go_bf16_tp = go_bf16.clone()
x_bf16_sp_input = distribute_tensor(x_bf16.clone(), mesh, [Shard(0)])
go_bf16_sp = distribute_tensor(go_bf16.clone(), mesh, [Shard(0)])

tp_out = tp_model(x_bf16_tp_input)
tp_out.backward(go_bf16_tp)
sp_out = sp_model(x_bf16_sp_input)
sp_out.backward(go_bf16_sp)
global_out = toy_model_fp8(x_bf16)
global_out.backward(go_bf16)

MIN_SQNR = 23.0

if not torch.allclose(tp_out, global_out):
print(
f"tp out comparison not close, shapes: tp={tp_out.shape}, global={global_out.shape}"
)
print(
f"tp_out stats: min={tp_out.min()}, max={tp_out.max()}, mean={tp_out.mean()}"
)
print(
f"global_out stats: min={global_out.min()}, max={global_out.max()}, mean={global_out.mean()}"
)
diff = (tp_out - global_out).abs()
print(f"diff stats: min={diff.min()}, max={diff.max()}, mean={diff.mean()}")

tp_out_sqnr = compute_error(tp_out, global_out)
print(f"tp_out SQNR: {tp_out_sqnr}")
assert tp_out_sqnr >= MIN_SQNR, f"tp_out SQNR {tp_out_sqnr} < {MIN_SQNR}"

sp_out_sqnr = compute_error(sp_out.full_tensor(), global_out)
assert sp_out_sqnr >= MIN_SQNR, f"sp_out SQNR {sp_out_sqnr} < {MIN_SQNR}"

w1_grad_sqnr = compute_error(
tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad
)
assert w1_grad_sqnr >= MIN_SQNR, f"w1.weight.grad SQNR {w1_grad_sqnr} < {MIN_SQNR}"

out_proj_grad_sqnr = compute_error(
tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
)
assert out_proj_grad_sqnr >= MIN_SQNR, (
f"out_proj.weight.grad SQNR {out_proj_grad_sqnr} < {MIN_SQNR}"
)

sp_out2 = sp_model2(x_bf16_sp_input)
sp_out2.backward(go_bf16_sp)

sp_out2 = sp_model2(x_fp32_sp_input)
sp_out2.backward(go_fp32_sp)
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
torch.testing.assert_close(
sp_out2_sqnr = compute_error(sp_out2.full_tensor(), global_out)
assert sp_out2_sqnr >= MIN_SQNR, f"sp_out2 SQNR {sp_out2_sqnr} < {MIN_SQNR}"

w1_grad2_sqnr = compute_error(
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
)
torch.testing.assert_close(
assert w1_grad2_sqnr >= MIN_SQNR, (
f"w1.weight.grad (sp_model2) SQNR {w1_grad2_sqnr} < {MIN_SQNR}"
)

out_proj_grad2_sqnr = compute_error(
tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
)
assert out_proj_grad2_sqnr >= MIN_SQNR, (
f"out_proj.weight.grad (sp_model2) SQNR {out_proj_grad2_sqnr} < {MIN_SQNR}"
)
Loading