is_pytorch_2_9 = pytorch_build_version_parsed.release[:2] == (2, 9)
is_pytorch_2_11_or_later = pytorch_build_version_parsed.release[:2] >= (2, 11)

# aotriton is not supported on certain architectures yet.
# gfx900/gfx906/gfx908/gfx101X/gfx103X: https://github.com/ROCm/TheRock/issues/1925
AOTRITON_UNSUPPORTED_ARCHS = ["gfx900", "gfx906", "gfx908", "gfx101", "gfx103"]
# gfx1152/53: supported in aotriton 0.11.2b+ (https://github.com/ROCm/aotriton/pull/142),
# which is pinned by pytorch >= 2.11. Older versions don't include it.
if not is_pytorch_2_11_or_later:
    AOTRITON_UNSUPPORTED_ARCHS += ["gfx1152", "gfx1153"]

## Enable FBGEMM_GENAI on Linux for PyTorch, as it is available only for 2.9 on rocm/pytorch
## and causes build failures for other PyTorch versions
## Warn user when enabling it manually.
## https://github.com/ROCm/TheRock/issues/2056
if not is_windows:
    # Enabling/Disabling FBGEMM_GENAI based on Pytorch version in Linux
    if is_pytorch_2_9:
        # Default ON for 2.9.x, unless explicitly disabled
        # args.enable_pytorch_fbgemm_genai_linux can be set to false
        # by passing --no-enable-pytorch-fbgemm-genai-linux as input
        if args.enable_pytorch_fbgemm_genai_linux is False:
            use_fbgemm_genai = "OFF"
            print(f" [WARN] User-requested override to set FBGEMM_GENAI = OFF.")
        else:
            use_fbgemm_genai = "ON"
    else:
        # Default OFF for all other versions, unless explicitly enabled
        if args.enable_pytorch_fbgemm_genai_linux is True:
            use_fbgemm_genai = "ON"
        else:
            use_fbgemm_genai = "OFF"

        if use_fbgemm_genai == "ON":
            print(f" [WARN] User-requested override to set FBGEMM_GENAI = ON.")
            print(
                f""" [WARN] Please note that FBGEMM_GENAI is not available for PyTorch 2.7, and enabling it may cause build failures
                for PyTorch >= 2.8 (Except 2.9). See status of issue https://github.com/ROCm/TheRock/issues/2056
                """
            )

    env["USE_FBGEMM_GENAI"] = use_fbgemm_genai
    print(f"FBGEMM_GENAI enabled: {env['USE_FBGEMM_GENAI'] == 'ON'}")

    if args.enable_pytorch_flash_attention_linux is None:
        # Default behavior — determined by if triton is build
        use_flash_attention = "ON" if triton_requirement else "OFF"

        if any(
            arch in env["PYTORCH_ROCM_ARCH"] for arch in AOTRITON_UNSUPPORTED_ARCHS
        ):
            use_flash_attention = "OFF"
        print(
            f"Flash Attention default behavior (based on triton and gpu): {use_flash_attention}"
        )
    else:
        # Explicit override: user has set the flag to true/false
        if args.enable_pytorch_flash_attention_linux:
            assert (
                triton_requirement
            ), "Must build with triton if wanting to use flash attention"
            use_flash_attention = "ON"
        else:
            use_flash_attention = "OFF"

        print(f"Flash Attention override set by flag: {use_flash_attention}")

    env.update(
        {
            "USE_FLASH_ATTENTION": use_flash_attention,
            "USE_MEM_EFF_ATTENTION": use_flash_attention,
        }
    )
    print(
        f"Flash Attention and Memory efficiency enabled: {env['USE_FLASH_ATTENTION'] == 'ON'}"
    )

env["USE_ROCM"] = "ON"
env["USE_CUDA"] = "OFF"
env["USE_MPI"] = "OFF"
env["USE_NUMA"] = "OFF"
env["PYTORCH_BUILD_VERSION"] = pytorch_build_version
env["PYTORCH_BUILD_NUMBER"] = args.pytorch_build_number

# Determine which install requirements to add.
install_requirements = [
    f"rocm[libraries]=={get_rocm_sdk_version()}",
]
if triton_requirement:
    install_requirements.append(triton_requirement)
env["PYTORCH_EXTRA_INSTALL_REQUIREMENTS"] = "|".join(install_requirements)
print(
    f"--- PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {env['PYTORCH_EXTRA_INSTALL_REQUIREMENTS']}"
)

# Add the _rocm_init.py file.
(pytorch_dir / "torch" / "_rocm_init.py").write_text(get_rocm_init_contents(args))

# Windows-specific settings.
if is_windows:
    copy_msvc_libomp_to_torch_lib(pytorch_dir)

    use_flash_attention = "0"

    if args.enable_pytorch_flash_attention_windows and not any(
        arch in env["PYTORCH_ROCM_ARCH"] for arch in AOTRITON_UNSUPPORTED_ARCHS
    ):
        use_flash_attention = "1"

    env.update(
        {
            "USE_FLASH_ATTENTION": use_flash_attention,
            "USE_MEM_EFF_ATTENTION": use_flash_attention,
            "DISTUTILS_USE_SDK": "1",
            # Workaround compile errors in 'aten/src/ATen/test/hip/hip_vectorized_test.hip'
            # on Torch 2.7.0: https://gist.github.com/ScottTodd/befdaf6c02a8af561f5ac1a2bc9c7a76.
            #   error: no member named 'modern' in namespace 'at::native'
            #     using namespace at::native::modern::detail;
            #   error: no template named 'has_same_arg_types'
            #     static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
            # We may want to fix that and other issues to then enable building tests.
            "BUILD_TEST": "0",
        }
    )
print(
    f" Flash attention enabled: {args.enable_pytorch_flash_attention_windows or not is_windows}"
)
Background
For the prior (non-multi-arch) releases, we built PyTorch independently for each GPU target family, e.g. gfx1151 and gfx110X-all. This produced a different torch package for each GPU target family, which we then distributed via isolated pip index pages.
For the new multi-arch releases, we build PyTorch for all supported GPU targets in a single build and then split device-specific code out into separate packages as a post-processing step.
The multi-arch builds are configured with:

PYTORCH_ROCM_ARCH=gfx1100;gfx1101;gfx1102;gfx1103;gfx1151;gfx1200;gfx1201;gfx900;gfx906;gfx908;gfx90a;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1033;gfx1034;gfx1035;gfx1036;gfx1150;gfx1152;gfx1153

The issue
This release workflow run, https://github.com/ROCm/rockrel/actions/runs/25156704846, built with USE_FLASH_ATTENTION=0 and USE_MEM_EFF_ATTENTION=0.

Here's the code where we decide whether to enable flash attention:
TheRock/external-builds/pytorch/build_prod_wheels.py, lines 939 to 1065 at 4f0fe3e
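For reference, the result can also be checked from an installed wheel. This is a rough sketch, assuming a visible ROCm GPU and a torch version that exposes torch.backends.cuda.is_flash_attention_available():

# Rough check against an installed wheel; assumes a ROCm GPU is visible and
# that this torch version exposes is_flash_attention_available().
import torch

print(torch.version.hip)          # non-None for a ROCm build
print(torch.cuda.is_available())  # expect True on a supported gfx target

# Should report False for a wheel built with USE_FLASH_ATTENTION=0, as in the
# workflow run above.
print(torch.backends.cuda.is_flash_attention_available())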
Notes:
- If any arch in PYTORCH_ROCM_ARCH is in AOTRITON_UNSUPPORTED_ARCHS, the features are disabled completely (see the sketch after these notes).
- The final "Flash attention enabled" log line reports args.enable_pytorch_flash_attention_windows and NOT the computed use_flash_attention.
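To make the first note concrete, here is a minimal standalone sketch of the same check; only the arch list and the PYTORCH_ROCM_ARCH value are copied from above, the other names are illustrative. Because env["PYTORCH_ROCM_ARCH"] is the semicolon-joined string, arch in env["PYTORCH_ROCM_ARCH"] is a substring test, so the multi-arch value trips the check both through the exact gfx900/gfx906/gfx908 entries and through the gfx101/gfx103 prefixes matching the gfx101X/gfx103X targets:

# Standalone reproduction of the arch check quoted above, using the
# multi-arch PYTORCH_ROCM_ARCH value from the Background section.
AOTRITON_UNSUPPORTED_ARCHS = ["gfx900", "gfx906", "gfx908", "gfx101", "gfx103"]

pytorch_rocm_arch = (
    "gfx1100;gfx1101;gfx1102;gfx1103;gfx1151;gfx1200;gfx1201;"
    "gfx900;gfx906;gfx908;gfx90a;gfx1010;gfx1011;gfx1012;"
    "gfx1030;gfx1031;gfx1032;gfx1033;gfx1034;gfx1035;gfx1036;"
    "gfx1150;gfx1152;gfx1153"
)

# Same expression as in build_prod_wheels.py: a substring test against the
# joined string, not a per-target comparison.
hits = [arch for arch in AOTRITON_UNSUPPORTED_ARCHS if arch in pytorch_rocm_arch]
print(hits)
# ['gfx900', 'gfx906', 'gfx908', 'gfx101', 'gfx103']

# any(...) is therefore True, so the single multi-arch build sets
# USE_FLASH_ATTENTION / USE_MEM_EFF_ATTENTION to OFF for every target,
# including targets that are not on the unsupported list (e.g. gfx90a,
# gfx1100, gfx1201).
print(any(arch in pytorch_rocm_arch for arch in AOTRITON_UNSUPPORTED_ARCHS))
# True

So unless args.enable_pytorch_flash_attention_linux is explicitly set to true, a single build covering all of these targets always ends up with both features OFF.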
Other implications/symptoms

aotriton builds both target-specific and family-specific kernels, and the kpack splitting step now produces separate device packages for each of them.
Our dev release index at https://rocm.devreleases.amd.com/whl-staging-multi-arch/ includes builds for different combinations of targets depending on what we're testing, while our nightly release index always builds for all targets. The family-specific packages (from having aotriton enabled) are only in the devreleases index:
torch\lib\aotriton.images\amd-gfx120x\flash\attn_fwd\

cc @xinyazhang, @HereThereBeDragons, @araravik-psd, @marbre