Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import unittest

import pytest
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal import common_utils
Expand Down Expand Up @@ -36,9 +35,6 @@
except ModuleNotFoundError:
has_gemlite = False

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)


class TestAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses"""
Expand Down Expand Up @@ -196,8 +192,10 @@ def test_tp(self, dtype):
common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
# Float8 TP requires FP8-capable hardware (H100+ on CUDA, MI300+ on ROCm)
from torchao.utils import is_MI300, is_MI350, is_sm_at_least_90

if torch.cuda.is_available() and (is_sm_at_least_90() or is_MI300() or is_MI350()):

class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(Float8WeightOnlyConfig)
Expand Down
21 changes: 13 additions & 8 deletions test/float8/test_fsdp2/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
TransformerBlock,
)

from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
e4m3_dtype,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_training_tensor import GemmInputRole
Expand All @@ -40,13 +45,13 @@
check_parity_bf16_mp,
check_parity_no_mp,
)
from torchao.utils import is_sm_at_least_89

if not is_sm_at_least_89():
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
from torchao.utils import is_MI300, is_MI350, is_sm_at_least_89

if torch.version.hip is not None:
pytest.skip("ROCm enablement in progress", allow_module_level=True)
if not (is_sm_at_least_89() or is_MI300() or is_MI350()):
pytest.skip(
"Requires FP8-capable GPU (CUDA SM89+, MI300, or MI350)",
allow_module_level=True,
)


class TestFloat8Common:
Expand Down Expand Up @@ -336,7 +341,7 @@ def test_amax_allreduce_device_mesh(self):
hp_tensor = torch.randn(768, 32, device="cuda")
hp_tensor_to_float8_dynamic(
hp_tensor,
torch.float8_e4m3fn,
e4m3_dtype,
Float8LinearConfig(
cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
),
Expand Down
10 changes: 6 additions & 4 deletions test/prototype/mx_formats/test_mx_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
"""

import os
import sys

import pytest
import torch

from torchao.utils import torch_version_at_least
# TODO: re-enable once mx training refactor is complete
_SKIP_MSG = "DTensor support incomplete, MXFP8 training refactor is not yet complete, see: https://github.com/pytorch/ao/pull/3985"

if not torch_version_at_least("2.7.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
if __name__ == "__main__":
print(f"SKIPPED: {_SKIP_MSG}")
sys.exit(0)

from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
Expand Down
Loading