[Reland] ROCm CI (Infra + Skips) #1581

Merged — 33 commits, Feb 21, 2025

Commits
afb1d72  skip failing unit tests for ROCm CI (petrex, Jan 14, 2025)
ec02df5  update skip_if_rocm import (petrex, Jan 16, 2025)
d6ec43f  Enable ROCM in CI (msaroufim, Oct 3, 2024)
5c727e6  Update regression_test.yml (amdfaa, Jan 16, 2025)
a541d60  Update regression_test.yml (amdfaa, Jan 16, 2025)
5e46cba  Update regression_test.yml (amdfaa, Jan 17, 2025)
467236f  lint (petrex, Jan 21, 2025)
6d6f203  Enable ROCM in CI (#999) (msaroufim, Jan 17, 2025)
69db090  Update regression_test.yml (amdfaa, Jan 22, 2025)
7122221  Update regression_test.yml (amdfaa, Jan 22, 2025)
74887fa  skip ROCm tests (petrex, Jan 22, 2025)
46b0caf  skip rocm tests (petrex, Jan 22, 2025)
8b43a08  skip fsdp2 test for ROCm (petrex, Jan 22, 2025)
da45960  Update regression_test.yml (amdfaa, Jan 23, 2025)
7a267b8  skip smooth quant test (torch dynamo) (petrex, Jan 24, 2025)
f988edf  skip nf4 tests (petrex, Jan 24, 2025)
3168159  skip test for uneven shard (petrex, Jan 28, 2025)
01bab42  skip test low bit optim (petrex, Jan 28, 2025)
a7a021d  Update regression_test.yml (amdfaa, Feb 4, 2025)
09c0f8c  fix auto-merge (petrex, Feb 5, 2025)
387d321  Update regression_test.yml (amdfaa, Feb 7, 2025)
0b83758  Update regression_test.yml (amdfaa, Feb 11, 2025)
bef9d17  Update regression_test.yml (amdfaa, Feb 17, 2025)
fff25bd  Update regression_test.yml (amdfaa, Feb 20, 2025)
8426b7a  Merge branch 'main' into skipROCmTest (jithunnair-amd, Feb 20, 2025)
14bd4cc  Update test_ops.py (jithunnair-amd, Feb 20, 2025)
127b445  Attempt to disable only ROCm matrix entries for non-push-to-main (jithunnair-amd, Feb 20, 2025)
61e86c2  Attempt to disable only ROCm matrix entries for non-push-to-main - 2 (jithunnair-amd, Feb 20, 2025)
a6958d7  Add new regression_test_rocm.yml as per upstream recommendation (jithunnair-amd, Feb 21, 2025)
75e0058  Ruff fixes (jithunnair-amd, Feb 21, 2025)
e6ecd1f  Add skip_if_rocm decorator to test_workflow_e2e_numerics (petrex, Feb 21, 2025)
27d0d48  lint (petrex, Feb 21, 2025)
900cf5b  Add skip_if_rocm decorator to test_float8_utils (petrex, Feb 21, 2025)
13 changes: 8 additions & 5 deletions .github/workflows/regression_test.yml
@@ -33,13 +33,19 @@ jobs:
torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""

- name: ROCM Nightly
runs-on: linux.rocm.gpu.torchao
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3'
gpu-arch-type: "rocm"
gpu-arch-version: "6.3"
permissions:
id-token: write
contents: read
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@enable_linux_job_v2
with:
enabled: ${{ matrix.gpu-arch-type != 'rocm' || (github.event_name == 'push' && startsWith(github.ref, 'refs/heads/main')) }}
timeout: 120
no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }}
runner: ${{ matrix.runs-on }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
@@ -74,7 +80,6 @@ jobs:
torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"

- name: CPU 2.3
runs-on: linux.4xlarge
torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
@@ -102,8 +107,6 @@ jobs:
conda create -n venv python=3.9 -y
conda activate venv
echo "::group::Install newer objcopy that supports --set-section-alignment"
yum install -y devtoolset-10-binutils
export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r dev-requirements.txt
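The infra change above adds a ROCm matrix entry (self-hosted linux.rocm.gpu.torchao runner, ROCm 6.3 nightly wheels), pins the reusable workflow to the enable_linux_job_v2 branch of test-infra, sets no-sudo for the ROCm runner, and passes an `enabled` input so ROCm jobs run only on pushes to main while CPU/CUDA entries are unaffected. For illustration only, the boolean behind that gating can be rendered in Python (the function name is made up here; the real expression is evaluated by GitHub Actions):

```python
def rocm_job_enabled(gpu_arch_type: str, event_name: str, ref: str) -> bool:
    # Mirrors: matrix.gpu-arch-type != 'rocm'
    #          || (github.event_name == 'push' && startsWith(github.ref, 'refs/heads/main'))
    return gpu_arch_type != "rocm" or (
        event_name == "push" and ref.startswith("refs/heads/main")
    )


assert rocm_job_enabled("cuda", "pull_request", "refs/pull/1581/merge")      # non-ROCm entries always run
assert not rocm_job_enabled("rocm", "pull_request", "refs/pull/1581/merge")  # ROCm entries disabled on PRs
assert rocm_job_enabled("rocm", "push", "refs/heads/main")                   # ROCm entries run on push to main
```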
4 changes: 4 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -25,6 +25,7 @@
TORCH_VERSION_AT_LEAST_2_6,
is_fbcode,
is_sm_at_least_89,
skip_if_rocm,
)

is_cusparselt_available = (
@@ -104,6 +105,7 @@ def test_tensor_core_layout_transpose(self):
"apply_quant",
get_quantization_functions(is_cusparselt_available, True, "cuda", True),
)
@skip_if_rocm("ROCm enablement in progress")
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
if isinstance(apply_quant, AOBaseConfig):
@@ -196,6 +198,7 @@ def apply_uint6_weight_only_quant(linear):
"apply_quant", get_quantization_functions(is_cusparselt_available, True)
)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
def test_print_quantized_module(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
if isinstance(apply_quant, AOBaseConfig):
@@ -213,6 +216,7 @@ class TestAffineQuantizedBasic(TestCase):

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
@skip_if_rocm("ROCm enablement in progress")
def test_flatten_unflatten(self, device, dtype):
if device == "cuda" and dtype == torch.bfloat16 and is_fbcode():
raise unittest.SkipTest("TODO: Failing for cuda + bfloat16 in fbcode")
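Most of the test-side changes in this and the following files attach a `skip_if_rocm` decorator imported from torchao.utils. Its implementation is not part of this diff; a minimal sketch of what such a decorator could look like, assuming it keys off `torch.version.hip`, is shown below — the actual torchao helper may differ in signature and behavior:

```python
import functools
import unittest

import torch


def skip_if_rocm(message=None):
    """Hypothetical sketch: skip the wrapped test when PyTorch is a ROCm (HIP) build."""

    def decorator(func):
        reason = message or "Skipped on ROCm"

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if torch.version.hip is not None:
                # SkipTest is honored by both unittest and pytest runners.
                raise unittest.SkipTest(reason)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

Raising `unittest.SkipTest` at call time reports the test as skipped rather than failed on ROCm runners, which matches how the decorated tests are expected to show up in CI.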
4 changes: 4 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,5 +1,6 @@
import unittest

import pytest
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal import common_utils
@@ -27,6 +28,9 @@
except ModuleNotFoundError:
has_gemlite = False

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)


class TestAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses"""
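Where an entire file is not expected to work yet on ROCm (this tensor-parallel suite, and later the FSDP2, SmoothQuant, and low-bit-optim suites), the PR uses pytest's module-level skip instead of decorating each test — presumably to keep the diff small and avoid running any of the module's setup on ROCm. The pattern, as used above, is simply:

```python
import pytest
import torch

# Skip every test in this module when PyTorch is a ROCm (HIP) build.
if torch.version.hip is not None:
    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
```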
3 changes: 2 additions & 1 deletion test/dtypes/test_floatx.py
@@ -27,7 +27,7 @@
fpx_weight_only,
quantize_,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm

_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_Floatx_DTYPES = [(3, 2), (2, 2)]
@@ -109,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
@parametrize("bias", [False, True])
@parametrize("dtype", [torch.half, torch.bfloat16])
@unittest.skipIf(is_fbcode(), reason="broken in fbcode")
@skip_if_rocm("ROCm enablement in progress")
def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
N, OC, IC = 4, 256, 64
device = "cuda"
3 changes: 3 additions & 0 deletions test/dtypes/test_nf4.py
@@ -33,6 +33,7 @@
nf4_weight_only,
to_nf4,
)
from torchao.utils import skip_if_rocm

bnb_available = False

@@ -111,6 +112,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
@@ -133,6 +135,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_nf4_bnb_linear(self, dtype: torch.dtype):
"""
4 changes: 3 additions & 1 deletion test/dtypes/test_uint4.py
@@ -28,7 +28,7 @@
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm


def _apply_weight_only_uint4_quant(model):
@@ -92,6 +92,7 @@ def test_basic_tensor_ops(self):
# only test locally
# print("x:", x[0])

@skip_if_rocm("ROCm enablement in progress")
def test_gpu_quant(self):
for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
x = torch.randn(*x_shape)
@@ -104,6 +105,7 @@ def test_gpu_quant(self):
# make sure it runs
opt(x)

@skip_if_rocm("ROCm enablement in progress")
def test_pt2e_quant(self):
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
QuantizationConfig,
2 changes: 2 additions & 0 deletions test/float8/test_base.py
@@ -18,6 +18,7 @@
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_90,
skip_if_rocm,
)

if not TORCH_VERSION_AT_LEAST_2_5:
@@ -426,6 +427,7 @@ def test_linear_from_config_params(
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_bias", [True, False])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skip_if_rocm("ROCm enablement in progress")
def test_linear_from_recipe(
self,
recipe_name,
3 changes: 3 additions & 0 deletions test/float8/test_fsdp2/test_fsdp2.py
@@ -43,6 +43,9 @@
if not is_sm_at_least_89():
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

if torch.version.hip is not None:
pytest.skip("ROCm enablement in progress", allow_module_level=True)


class TestFloat8Common:
def broadcast_module(self, module: nn.Module) -> None:
2 changes: 2 additions & 0 deletions test/hqq/test_hqq_affine.py
@@ -11,6 +11,7 @@
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
skip_if_rocm,
)

cuda_available = torch.cuda.is_available()
@@ -109,6 +110,7 @@ def test_hqq_plain_5bit(self):
ref_dot_product_error=0.000704,
)

@skip_if_rocm("ROCm enablement in progress")
def test_hqq_plain_4bit(self):
self._test_hqq(
dtype=torch.uint4,
8 changes: 8 additions & 0 deletions test/integration/test_integration.py
@@ -85,6 +85,7 @@
benchmark_model,
is_fbcode,
is_sm_at_least_90,
skip_if_rocm,
unwrap_tensor_subclass,
)

@@ -95,6 +96,7 @@
except ModuleNotFoundError:
has_gemlite = False


logger = logging.getLogger("INFO")

torch.manual_seed(0)
@@ -582,6 +584,7 @@ def test_per_token_linear_cpu(self):
self._test_per_token_linear_impl("cpu", dtype)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
def test_per_token_linear_cuda(self):
for dtype in (torch.float32, torch.float16, torch.bfloat16):
self._test_per_token_linear_impl("cuda", dtype)
@@ -700,6 +703,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -719,6 +723,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -912,6 +917,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -931,6 +937,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -1102,6 +1109,7 @@ def test_gemlite_layout(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
3 changes: 3 additions & 0 deletions test/kernel/test_fused_kernels.py
@@ -11,6 +11,8 @@
import torch
from galore_test_utils import get_kernel, make_copy, make_data

from torchao.utils import skip_if_rocm

torch.manual_seed(0)
MAX_DIFF_no_tf32 = 1e-5
MAX_DIFF_tf32 = 1e-3
@@ -104,6 +106,7 @@ def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS)
@skip_if_rocm("ROCm enablement in progress")
def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32):
torch.backends.cuda.matmul.allow_tf32 = allow_tf32

2 changes: 2 additions & 0 deletions test/kernel/test_galore_downproj.py
@@ -11,6 +11,7 @@

from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
from torchao.utils import skip_if_rocm

torch.manual_seed(0)

@@ -29,6 +30,7 @@

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
@skip_if_rocm("ROCm enablement in progress")
def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32
7 changes: 6 additions & 1 deletion test/prototype/test_awq.py
@@ -5,7 +5,11 @@
import torch

from torchao.quantization import quantize_
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
skip_if_rocm,
)

if TORCH_VERSION_AT_LEAST_2_3:
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
@@ -113,6 +117,7 @@ def test_awq_loading(device, qdtype):

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@skip_if_rocm("ROCm enablement in progress")
def test_save_weights_only():
dataset_size = 100
l1, l2, l3 = 512, 256, 128
7 changes: 7 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -30,6 +30,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
get_available_devices,
skip_if_rocm,
)

try:
@@ -42,6 +43,8 @@
except ImportError:
lpmm = None

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)

_DEVICES = get_available_devices()

@@ -112,6 +115,7 @@ class TestOptim(TestCase):
)
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
@skip_if_rocm("ROCm enablement in progress")
def test_optim_smoke(self, optim_name, dtype, device):
if optim_name.endswith("Fp8") and device == "cuda":
if not TORCH_VERSION_AT_LEAST_2_4:
@@ -185,6 +189,7 @@ def test_subclass_slice(self, subclass, shape, device):
not torch.cuda.is_available(),
reason="bitsandbytes 8-bit Adam only works for CUDA",
)
@skip_if_rocm("ROCm enablement in progress")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
@@ -413,6 +418,7 @@ def world_size(self) -> int:
not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
)
@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
@skip_if_rocm("ROCm enablement in progress")
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
if torch.cuda.get_device_capability() >= (8, 9):
@@ -523,6 +529,7 @@ def _test_fsdp2(self, optim_cls):
not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
)
@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
@skip_if_rocm("ROCm enablement in progress")
def test_uneven_shard(self):
in_dim = 512
out_dim = _FSDP_WORLD_SIZE * 16 + 1
3 changes: 3 additions & 0 deletions test/prototype/test_smoothquant.py
@@ -20,6 +20,9 @@
TORCH_VERSION_AT_LEAST_2_5,
)

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
4 changes: 3 additions & 1 deletion test/prototype/test_splitk.py
@@ -13,13 +13,15 @@
except ImportError:
triton_available = False

from torchao.utils import skip_if_compute_capability_less_than

from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm


@unittest.skipIf(not triton_available, "Triton is required but not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestFP8Gemm(TestCase):
@skip_if_compute_capability_less_than(9.0)
@skip_if_rocm("ROCm enablement in progress")
def test_gemm_split_k(self):
dtype = torch.float16
qdtype = torch.float8_e4m3fn