diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index 48e43c5dee..ef1e59e41f 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -84,6 +84,12 @@ def __init__( "torch.compile enablement is required for highest performance of MXFP8 dynamic quantization." ) + if parallel_dims.tp > 1: + logger.warning( + "TP support for MXFP8 linears is still in progress. " + "Any linear layers with TP applied will use default precision, not MXFP8." + ) + self.config = config self.enabled = True logger.info("MXFP8 MoE training enabled")