diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py index 9939fcaf80..e1c9358201 100644 --- a/test/prototype/mx_formats/test_mx_dtensor.py +++ b/test/prototype/mx_formats/test_mx_dtensor.py @@ -24,7 +24,10 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from tqdm import tqdm -from torchao.prototype.moe_training.config import MXFP8TrainingOpConfig +from torchao.prototype.moe_training.config import ( + MXFP8TrainingOpConfig, + MXFP8TrainingRecipe, +) from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.testing.training.dtensor_utils import ( _test_lowp_mlp_tensor_parallelism_base, @@ -43,52 +46,80 @@ def setup_distributed(): return device_mesh -def _test_dtensor_cast_to_mxfp4(mesh: DeviceMesh, size=4): +def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=1024): device = mesh.device_type x_fp32 = torch.rand(size, size, device=device) - x_fp4 = MXTensor.to_mx(x_fp32, torch.float4_e2m1fn_x2, block_size=size // 2) + x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=32) dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) - dist_x_fp4 = MXTensor.to_mx( - dist_x_fp32, torch.float4_e2m1fn_x2, block_size=size // 2 + dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=32) + + # With the new wrapping order, MXTensor is the outer wrapper with DTensor + # inner tensors (MXTensor(DTensor_qdata, DTensor_scale)). + assert isinstance(dist_x_fp8, MXTensor), ( + f"Expected MXTensor, got {type(dist_x_fp8)}" + ) + assert isinstance(dist_x_fp8.qdata, DTensor), ( + f"Expected qdata to be DTensor, got {type(dist_x_fp8.qdata)}" + ) + assert isinstance(dist_x_fp8.scale, DTensor), ( + f"Expected scale to be DTensor, got {type(dist_x_fp8.scale)}" ) - assert isinstance(dist_x_fp4, DTensor) # Verify that the result of to_mx with DTensor matches the slice of the # result of to_mx without DTensor. This will fail on numeric op mismatches. local_rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() assert size % world_size == 0, "unsupported" - x_fp4_fp32 = x_fp4.dequantize(torch.float32) + x_fp8_fp32 = x_fp8.dequantize(torch.bfloat16) rows_per_slice = size // world_size slice_start = local_rank * rows_per_slice slice_end = (local_rank + 1) * rows_per_slice - x_fp4_fp32_slice = x_fp4_fp32[slice_start:slice_end] + x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end] + # dequantize handles DTensor inner tensors and returns a DTensor + dist_x_fp8_dequant = dist_x_fp8.dequantize(torch.bfloat16) + assert isinstance(dist_x_fp8_dequant, DTensor), ( + f"Expected dequantize result to be DTensor, got {type(dist_x_fp8_dequant)}" + ) torch.testing.assert_close( - x_fp4_fp32_slice, - dist_x_fp4.to_local().dequantize(torch.float32), + x_fp8_fp32_slice, + dist_x_fp8_dequant.to_local(), atol=0, rtol=0, ) -def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128): - config = MXFP8TrainingOpConfig() +def _test_mxfp8_mlp_tensor_parallelism_emulated(mesh: DeviceMesh, size=64): + recipe = MXFP8TrainingRecipe("mxfp8_emulated_rceil") + config = MXFP8TrainingOpConfig.from_recipe(recipe) _test_lowp_mlp_tensor_parallelism_base( mesh, config, size, compile=False, allgather_in_lowp=False ) + + +def _test_mxfp8_mlp_tensor_parallelism_auto(mesh: DeviceMesh, size=64): + recipe = MXFP8TrainingRecipe("mxfp8_rceil") + config = MXFP8TrainingOpConfig.from_recipe(recipe) _test_lowp_mlp_tensor_parallelism_base( - mesh, config, size, compile=True, allgather_in_lowp=False + mesh, config, size, compile=False, allgather_in_lowp=False ) if __name__ == "__main__": device_mesh = setup_distributed() tests = [ - _test_dtensor_cast_to_mxfp4, - _test_mxfp8_mlp_tensor_parallelism, + _test_dtensor_cast_to_mxfp8, + _test_mxfp8_mlp_tensor_parallelism_emulated, ] + from torchao.prototype.moe_training.kernels.mxfp8.quant import ( + _mxfp8_cuda_kernels_available, + ) + + if _mxfp8_cuda_kernels_available: + tests.append(_test_mxfp8_mlp_tensor_parallelism_auto) + else: + print("Skipping auto test: requires SM >= 100 and CUDA >= 12.8") for test in tqdm(tests, desc="Running tests"): try: diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index d1aa36e508..1c734fa847 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -43,6 +43,8 @@ torch.ops.aten.clone.default, torch.ops.aten.transpose.int, torch.ops.aten.t.default, + # required for TP - scatter_ is used to distribute weights + torch.ops.c10d.scatter_.default, } @@ -288,7 +290,7 @@ def __torch_function__(cls, func, types, args, kwargs={}): if A_is_2d and B_is_2d_or_3d and offs is not None: return _quantize_then_scaled_grouped_mm( A, - B, + unwrap_weight(B), offs=offs, config=config, ) @@ -296,8 +298,8 @@ def __torch_function__(cls, func, types, args, kwargs={}): # linear op override elif func.__name__ in ("linear", "mm", "matmul", "addmm"): A, B = args[0], args[1] - assert not isinstance(A, cls), f"A should not be a {cls.__name__}" + assert not isinstance(A, cls), f"A should not be a {cls.__name__}" assert isinstance(B, cls), f"B should be a {cls.__name__}" config = B.config @@ -307,7 +309,7 @@ def __torch_function__(cls, func, types, args, kwargs={}): return _to_mxfp8_then_scaled_mm( A, - B, + unwrap_weight(B), kernel_preference=config.kernel_preference, scale_calculation_mode=config.scale_calculation_mode, wgrad_with_hp=config.wgrad_with_hp, diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 74cf8397da..b7459628d0 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1122,35 +1122,41 @@ def custom_mxfp8_quantize_cuda_dim1_sharding( fp8_format: str, scaling_mode: str, ): - # This function signature can be used to understand the shardings: - # _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise=False, colwise=True) - - # When inputs and scale are replicated, we return a quantized output tensor (replicated). - inputs_replicated = [None, Replicate(), None, Replicate()] - outputs_replicated = [None, Replicate(), None, None] - rule_for_input_replicated = ( - inputs_replicated, - outputs_replicated, + # Op returns 4 tensors: (output_rowwise, output_colwise, scales_rowwise, scales_colwise) + # When rowwise=False, outputs 0 and 2 are empty tensors (size 0). + # output_colwise has shape (rows, cols) in col-major order. + # scales_colwise has shape (cols, num_row_blocks) in col-major order. + # + # Format: (output_placements, input_placements) + # Input placements: one per arg (x=Tensor, then 6 non-tensor args=None) + # Output placements: one per output tensor (4 total) + + non_tensor_args = [None, None, None, None, None, None] + + # When input is replicated, all outputs are replicated. + rule_replicated = ( + [Replicate(), Replicate(), Replicate(), Replicate()], + [Replicate()] + non_tensor_args, ) - # When inputs and scale are sharded along dim 0, - # we return a quantized output tensor (sharded along dim1 due to transpose). - inputs_sharded_dim0 = [None, Shard(0), None, Shard(0)] - outputs_sharded_dim1 = [None, Shard(1), None, None] - rule_for_input_sharded_dim0 = (inputs_sharded_dim0, outputs_sharded_dim1) - - # When inputs and scale are sharded along dim 1, - # we return a quantized output tensor (sharded along dim0 due to transpose). - inputs_sharded_dim1 = [None, Shard(1), None, Shard(1)] - outputs_sharded_dim0 = [None, Shard(0), None, None] - rule_for_input_sharded_dim1 = (inputs_sharded_dim1, outputs_sharded_dim0) - - acceptable_shardings = [ - rule_for_input_replicated, - rule_for_input_sharded_dim0, - rule_for_input_sharded_dim1, - ] - return acceptable_shardings + # When input is sharded along dim 0: + # output_colwise (rows, cols) col-major: rows are sharded → Shard(0) + # scales_colwise (cols, num_row_blocks) col-major: row blocks sharded → Shard(1) + # Unused rowwise outputs (empty tensors): Replicate() + rule_shard_dim0 = ( + [Replicate(), Shard(0), Replicate(), Shard(1)], + [Shard(0)] + non_tensor_args, + ) + + # When input is sharded along dim 1: + # output_colwise: cols are sharded → Shard(1) + # scales_colwise: col dim is sharded → Shard(0) + rule_shard_dim1 = ( + [Replicate(), Shard(1), Replicate(), Shard(0)], + [Shard(1)] + non_tensor_args, + ) + + return [rule_replicated, rule_shard_dim0, rule_shard_dim1] else: diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index c693156f59..90422f967e 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -23,7 +23,8 @@ import torch import torch.nn.functional as F -from torch.distributed._tensor import DTensor +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor.experimental import local_map from torch.utils._python_dispatch import ( return_and_correct_aliasing, ) @@ -353,7 +354,7 @@ def get_fp_scale(scale_e8m0): s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS # TODO(later): it would be nice if there was a way to do the 2^x operation # in PyTorch without creating a tensor of twos - two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) + two = torch.full_like(s_offset, 2.0, dtype=torch.float32) # pow(two, s_offset) can be out of range of floating point formats. # TODO(later): handle this for float16 if we decide to support float16 # scales. @@ -560,35 +561,6 @@ def from_qdata_and_scales( ) elem_dtype = qdata.dtype - if isinstance(qdata, DTensor) or isinstance(scales, DTensor): - assert isinstance(qdata, DTensor) and isinstance(scales, DTensor), ( - "qdata and scales must either both be DTensors or both be local tensors" - ) - assert qdata.device_mesh == scales.device_mesh, ( - "qdata and scales DTensors must have the same device mesh" - ) - assert qdata.placements == scales.placements, ( - "qdata and scales DTensors must have the same placements" - ) - inner_mx_tensor = MXTensor( - qdata.to_local(), - scales.to_local(), - elem_dtype, - block_size, - orig_dtype, - kernel_preference, - act_quant_kwargs, - is_swizzled_scales, - ) - return DTensor.from_local( - inner_mx_tensor, - qdata.device_mesh, - qdata.placements, - run_check=False, - shape=qdata.size(), - stride=qdata.stride(), - ) - return MXTensor( qdata, scales, @@ -640,28 +612,6 @@ def to_mx( inner_block_size=block_size, scaling_mode=scaling_mode.value, ) - if isinstance(scale_e8m0_biased, DTensor): - assert isinstance(data_lp, DTensor), "unsupported" - local_scale_e8m0_biased = scale_e8m0_biased.to_local() - local_data_lp = data_lp.to_local() - inner_mx_tensor = MXTensor( - local_data_lp, - local_scale_e8m0_biased, - elem_dtype, - block_size, - data_hp.dtype, - kernel_preference, - act_quant_kwargs, - is_swizzled_scales, - ) - return DTensor.from_local( - inner_mx_tensor, - data_lp.device_mesh, - data_lp.placements, - run_check=False, - shape=data_lp.size(), - stride=data_lp.stride(), - ) return MXTensor( data_lp, scale_e8m0_biased, @@ -703,6 +653,29 @@ def _get_gemm_choice( return choice_a if choice_a is not None else choice_b +def maybe_dtensor_to_blocked(t: torch.Tensor) -> torch.Tensor: + # redistribute to Replicate or Shard(0); to_blocked will view/permute/flatten into a 1d tensor + # sharding is only preservable on the first dimension. + if isinstance(t, DTensor): + t_placements = [ + x if x in (Replicate(), Shard(0)) else Replicate() for x in t.placements + ] + if t_placements != t.placements: # can't perform collectives in float8 + t = ( + t.view(torch.uint8) + .redistribute(placements=t_placements) + .view(torch.float8_e8m0fnu) + ) + out = local_map( + to_blocked, + in_placements=(t_placements,), + out_placements=t_placements, + )(t) + else: + out = to_blocked(t) + return out + + def _addmm_mx_dispatch( a: torch.Tensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -737,13 +710,13 @@ def _addmm_mx_dispatch( a_scale_block = a.scale else: a_scale = a.scale.view(M, K // a.block_size) - a_scale_block = to_blocked(a_scale) + a_scale_block = maybe_dtensor_to_blocked(a_scale) if b.is_swizzled_scales: b_scale_block = b.scale.t() else: b_scale = b.scale.t().view(N, K // b.block_size) - b_scale_block = to_blocked(b_scale) + b_scale_block = maybe_dtensor_to_blocked(b_scale) if a.elem_dtype == torch.float8_e4m3fn: assert b.elem_dtype == torch.float8_e4m3fn diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index b9c832852c..2a4052ac60 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -8,7 +8,6 @@ from typing import Tuple import torch -from torch.distributed._tensor import DTensor from torchao.prototype.mx_formats.config import ( MXFP8Dim1CastKernelChoice, @@ -192,39 +191,18 @@ def _to_mxfp8_dim1_kernel_wrapper( raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}") is_swizzled_scales = False - if isinstance(a_data, DTensor): - assert isinstance(a_scale, DTensor) - a_data_local = a_data.to_local() - a_scale_local = a_scale.to_local() - inner = MXTensor( - a_data_local.t(), - a_scale_local, - elem_dtype, - block_size, - hp_dtype, - kernel_preference, - None, - is_swizzled_scales, - ) - mx_tensor = DTensor.from_local( - inner, - a_data.device_mesh, - a_data.placements, - run_check=False, - shape=a_data.t().size(), - stride=a_data.t().stride(), - ) - else: - mx_tensor = MXTensor( - a_data.t(), - a_scale, - elem_dtype, - block_size, - hp_dtype, - kernel_preference, - None, - is_swizzled_scales, - ) + # MXTensor wraps DTensor inner tensors directly (MXTensor(DTensor) ordering). + # DTensor's .t() handles placement transposition automatically. + mx_tensor = MXTensor( + a_data.t(), + a_scale, + elem_dtype, + block_size, + hp_dtype, + kernel_preference, + None, + is_swizzled_scales, + ) return mx_tensor diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py index 75c8f637c6..4536386c09 100644 --- a/torchao/testing/training/dtensor_utils.py +++ b/torchao/testing/training/dtensor_utils.py @@ -64,18 +64,18 @@ def _test_lowp_mlp_tensor_parallelism_base( # TODO(future): remove this once float8 training works with `quantize_` API convert_model_func = convert_to_float8_training - if isinstance(config, MXFP8TrainingOpConfig): + is_mxfp8 = isinstance(config, MXFP8TrainingOpConfig) + if is_mxfp8: convert_model_func = quantize_ toy_model = ToyModel(size).to(device) + if is_mxfp8: + toy_model = toy_model.to(torch.bfloat16) + + # Non-TP reference model toy_model_fp8 = copy.deepcopy(toy_model) convert_model_func(toy_model_fp8, config=config) - tp_model = copy.deepcopy(toy_model) - convert_model_func(tp_model, config=config) - sp_model = copy.deepcopy(toy_model) - convert_model_func(sp_model, config=config) - # For tensorwise scaling, enable float8 all_gather. # For rowwise scaling, keep high precision all_gather. Motivation for # not doing float8 all-gather for rowwise: tensors need to be scaled both ways, @@ -90,7 +90,17 @@ def _test_lowp_mlp_tensor_parallelism_base( rowwise_parallel_cls = Float8RowwiseParallel prepare_input_cls = PrepareFloat8ModuleInput + # For MXFP8: parallelize first, then quantize. + # This puts MXFP8 wrapper on top of DTensor so __torch_function__ + # intercepts F.linear before DTensor can trigger premature all-gathers. + # + # For Float8: quantize first, then parallelize (original behavior). + # Float8 TP strategies (Float8ColwiseParallel etc.) expect Float8 weights. + # vanilla TP + tp_model = copy.deepcopy(toy_model) + if not is_mxfp8: + convert_model_func(tp_model, config=config) tp_model = parallelize_module( tp_model, mesh, @@ -100,8 +110,13 @@ def _test_lowp_mlp_tensor_parallelism_base( "ffn.out_proj": rowwise_parallel_cls(), }, ) + if is_mxfp8: + convert_model_func(tp_model, config=config) # "sequence parallel" mlp computation + sp_model = copy.deepcopy(toy_model) + if not is_mxfp8: + convert_model_func(sp_model, config=config) sp_model = parallelize_module( sp_model, mesh, @@ -116,10 +131,13 @@ def _test_lowp_mlp_tensor_parallelism_base( ), }, ) + if is_mxfp8: + convert_model_func(sp_model, config=config) # prepare_input_cls with specific submodule fqn sp_model2 = copy.deepcopy(toy_model) - convert_model_func(sp_model2, config=config) + if not is_mxfp8: + convert_model_func(sp_model2, config=config) if not allgather_in_lowp: prepare_input = prepare_input_cls( @@ -145,38 +163,50 @@ def _test_lowp_mlp_tensor_parallelism_base( ), }, ) + if is_mxfp8: + convert_model_func(sp_model2, config=config) if compile: tp_model = torch.compile(tp_model) sp_model = torch.compile(sp_model) sp_model2 = torch.compile(sp_model2) - x_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False) - go_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False) - x_fp32_tp_input = x_fp32.clone() - go_fp32_tp = go_fp32.clone() - x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) - go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)]) - - tp_out = tp_model(x_fp32_tp_input) - tp_out.backward(go_fp32_tp) - sp_out = sp_model(x_fp32_sp_input) - sp_out.backward(go_fp32_sp) - global_out = toy_model_fp8(x_fp32) - global_out.backward(go_fp32) + input_dtype = torch.bfloat16 if is_mxfp8 else torch.float32 + x = torch.rand( + 2, size * 2, size, device=device, requires_grad=False, dtype=input_dtype + ) + go = torch.rand( + 2, size * 2, size, device=device, requires_grad=False, dtype=input_dtype + ) + x_tp_input = x.clone() + go_tp = go.clone() + x_sp_input = distribute_tensor(x.clone(), mesh, [Shard(0)]) + go_sp = distribute_tensor(go.clone(), mesh, [Shard(0)]) + + tp_out = tp_model(x_tp_input) + tp_out.backward(go_tp) + + sp_out = sp_model(x_sp_input) + sp_out.backward(go_sp) + + global_out = toy_model_fp8(x) + global_out.backward(go) + torch.testing.assert_close(tp_out, global_out) torch.testing.assert_close(sp_out.full_tensor(), global_out) torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad) torch.testing.assert_close( - tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad + tp_model.ffn.out_proj.weight.grad, + sp_model.ffn.out_proj.weight.grad, ) - sp_out2 = sp_model2(x_fp32_sp_input) - sp_out2.backward(go_fp32_sp) + sp_out2 = sp_model2(x_sp_input) + sp_out2.backward(go_sp) torch.testing.assert_close(sp_out2.full_tensor(), global_out) torch.testing.assert_close( tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad ) torch.testing.assert_close( - tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad + tp_model.ffn.out_proj.weight.grad, + sp_model2.ffn.out_proj.weight.grad, )