diff --git a/.github/workflows/_linux-test-mi350.yml b/.github/workflows/_linux-test-mi350.yml
index 6e5f51fb2..7199d2257 100644
--- a/.github/workflows/_linux-test-mi350.yml
+++ b/.github/workflows/_linux-test-mi350.yml
@@ -20,7 +20,6 @@ jobs:
       SETUP_SCRIPT: "/workspace/setup_instance.sh"
       CONDA_ENV: ${{ inputs.conda_env }}
       DOCKER_IMAGE: "ghcr.io/meta-pytorch/tritonbench:rocm-latest"
-      TRITON_HIP_USE_ASYNC_COPY: "0"
     steps:
       - name: Checkout Tritonbench
         uses: actions/checkout@v3
diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 8b5c0e368..e6bf6baae 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -17,6 +17,7 @@
 import torch
 import triton
 import triton.language as tl
+from triton import knobs
 
 from .attention_utils import (
     HAS_EXPLICIT_WS,  # guard new tuning configs such as num_consumer_groups
@@ -28,6 +29,20 @@
 )
 
 
+def is_cuda():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
+def is_hip_async_copy_enabled():
+    if is_cuda():
+        return False
+
+    # default is enabled
+    if knobs.amd.use_async_copy is None:
+        return True
+    return knobs.amd.use_async_copy
+
+
 if HAS_TMA_DESC:
     print(
         "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
@@ -481,7 +496,7 @@ def get_fwd_config_space(
     bmList = [128] if enable_ws else [64, 128]
     bnList = [64, 128]  # To handle hDim of 64, we need BLOCK_N to be <= 64
     wList = [4] if enable_ws else [4, 8]
-    stageList = [2] if enable_ws else [3, 4, 7]
+    stageList = [2] if enable_ws else [3] if is_hip_async_copy_enabled() else [3, 4, 7]
     for BM in bmList:
         for BN in bnList:
             for sched in schedList:  # set in global scope
diff --git a/tritonbench/operators/fp8_gemm/persistent.py b/tritonbench/operators/fp8_gemm/persistent.py
index e42aa0a69..550a9e870 100644
--- a/tritonbench/operators/fp8_gemm/persistent.py
+++ b/tritonbench/operators/fp8_gemm/persistent.py
@@ -5,6 +5,7 @@
 import triton
 import triton.language as tl
 from torch._inductor.kernel.mm import ScalingType
+from triton import knobs
 
 from tritonbench.utils.env_utils import is_cuda
 from tritonbench.utils.triton_utils import has_experimental_descriptor
@@ -24,6 +25,16 @@
     pass
 
 
+def is_hip_async_copy_enabled():
+    if is_cuda():
+        return False
+
+    # default is enabled
+    if knobs.amd.use_async_copy is None:
+        return True
+    return knobs.amd.use_async_copy
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
@@ -135,7 +146,7 @@ def matmul_persistent(a, b):
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 8,
-            "num_stages": 4,
+            "num_stages": 3 if is_hip_async_copy_enabled() else 4,
             "num_warps": 8,
         },
         torch.float16: {
@@ -143,7 +154,7 @@ def matmul_persistent(a, b):
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3,
+            "num_stages": 2 if is_hip_async_copy_enabled() else 3,
             "num_warps": 8,
         },
         torch.bfloat16: {
@@ -151,7 +162,7 @@ def matmul_persistent(a, b):
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3,
+            "num_stages": 2 if is_hip_async_copy_enabled() else 3,
             "num_warps": 8,
         },
     }
diff --git a/tritonbench/operators/grouped_gemm/kernels.py b/tritonbench/operators/grouped_gemm/kernels.py
index beed269c1..28731c8ba 100644
--- a/tritonbench/operators/grouped_gemm/kernels.py
+++ b/tritonbench/operators/grouped_gemm/kernels.py
@@ -32,6 +32,7 @@
 import torch
 import triton
 import triton.language as tl
+from triton import knobs
 
 try:
     # @manual=//triton:triton
@@ -47,9 +48,17 @@ def is_cuda():
 
 
 def num_sms():
+    return torch.cuda.get_device_properties("cuda").multi_processor_count
+
+
+def is_hip_async_copy_enabled():
     if is_cuda():
-        return torch.cuda.get_device_properties("cuda").multi_processor_count
-    return 148
+        return False
+
+    # default is enabled
+    if knobs.amd.use_async_copy is None:
+        return True
+    return knobs.amd.use_async_copy
 
 
 def torch_dtype_to_triton_dtype(dtype):
@@ -73,7 +82,8 @@ def torch_dtype_to_triton_dtype(dtype):
             "BLOCK_SIZE_N": BLOCK_N,
             "BLOCK_SIZE_K": BLOCK_K,
             "NUM_SMS": num_sms(),
-        }
+        },
+        num_stages=2 if is_hip_async_copy_enabled() else 3,
     )
     for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product([128, 256], repeat=3)
 ],