
[Multi-arch] PyTorch packages build without flash attention when targets not supported by aotriton are included #4969

@ScottTodd

Description

Background

For the prior (non-multi-arch) releases, we built PyTorch independently for each GPU target family, e.g. gfx1151 and gfx110X-all. This produced a different torch package for each GPU target family, which we then distributed via isolated pip index pages.

For the new multi-arch releases, we build PyTorch for all supported GPU targets in a single build and then split device-specific code out into separate packages as a post-processing step.
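For context, PyTorch's ROCm build reads its target list from the PYTORCH_ROCM_ARCH environment variable as a single semicolon-separated string, so a multi-arch build passes every family at once. A minimal sketch (the specific target list here is illustrative, not the exact release configuration):

```python
# In a multi-arch build, PYTORCH_ROCM_ARCH carries all GPU targets at once
# as one semicolon-separated string (families below are illustrative).
multi_arch = "gfx900;gfx906;gfx908;gfx1030;gfx1100;gfx1151"
targets = multi_arch.split(";")
print(targets)  # every family is present in the same build
```

In the old scheme each build saw only one family (e.g. `PYTORCH_ROCM_ARCH="gfx1151"`), which is why per-family feature decisions used to be harmless.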

The issue

This release workflow run built with USE_FLASH_ATTENTION=0 and USE_MEM_EFF_ATTENTION=0: https://github.com/ROCm/rockrel/actions/runs/25156704846

Here's the code where we decide whether to enable flash attention:

```python
is_pytorch_2_9 = pytorch_build_version_parsed.release[:2] == (2, 9)
is_pytorch_2_11_or_later = pytorch_build_version_parsed.release[:2] >= (2, 11)
# aotriton is not supported on certain architectures yet.
# gfx900/gfx906/gfx908/gfx101X/gfx103X: https://github.com/ROCm/TheRock/issues/1925
AOTRITON_UNSUPPORTED_ARCHS = ["gfx900", "gfx906", "gfx908", "gfx101", "gfx103"]
# gfx1152/53: supported in aotriton 0.11.2b+ (https://github.com/ROCm/aotriton/pull/142),
# which is pinned by pytorch >= 2.11. Older versions don't include it.
if not is_pytorch_2_11_or_later:
    AOTRITON_UNSUPPORTED_ARCHS += ["gfx1152", "gfx1153"]

## Enable FBGEMM_GENAI on Linux for PyTorch, as it is available only for 2.9 on rocm/pytorch
## and causes build failures for other PyTorch versions.
## Warn user when enabling it manually.
## https://github.com/ROCm/TheRock/issues/2056
if not is_windows:
    # Enabling/Disabling FBGEMM_GENAI based on Pytorch version in Linux
    if is_pytorch_2_9:
        # Default ON for 2.9.x, unless explicitly disabled
        # args.enable_pytorch_fbgemm_genai_linux can be set to false
        # by passing --no-enable-pytorch-fbgemm-genai-linux as input
        if args.enable_pytorch_fbgemm_genai_linux is False:
            use_fbgemm_genai = "OFF"
            print(f"  [WARN] User-requested override to set FBGEMM_GENAI = OFF.")
        else:
            use_fbgemm_genai = "ON"
    else:
        # Default OFF for all other versions, unless explicitly enabled
        if args.enable_pytorch_fbgemm_genai_linux is True:
            use_fbgemm_genai = "ON"
        else:
            use_fbgemm_genai = "OFF"
        if use_fbgemm_genai == "ON":
            print(f"  [WARN] User-requested override to set FBGEMM_GENAI = ON.")
            print(
                f"""  [WARN] Please note that FBGEMM_GENAI is not available for PyTorch 2.7, and enabling it may cause build failures
  for PyTorch >= 2.8 (Except 2.9). See status of issue https://github.com/ROCm/TheRock/issues/2056
"""
            )
    env["USE_FBGEMM_GENAI"] = use_fbgemm_genai
    print(f"FBGEMM_GENAI enabled: {env['USE_FBGEMM_GENAI'] == 'ON'}")

    if args.enable_pytorch_flash_attention_linux is None:
        # Default behavior -- determined by whether triton is built
        use_flash_attention = "ON" if triton_requirement else "OFF"
        if any(
            arch in env["PYTORCH_ROCM_ARCH"] for arch in AOTRITON_UNSUPPORTED_ARCHS
        ):
            use_flash_attention = "OFF"
        print(
            f"Flash Attention default behavior (based on triton and gpu): {use_flash_attention}"
        )
    else:
        # Explicit override: user has set the flag to true/false
        if args.enable_pytorch_flash_attention_linux:
            assert (
                triton_requirement
            ), "Must build with triton if wanting to use flash attention"
            use_flash_attention = "ON"
        else:
            use_flash_attention = "OFF"
        print(f"Flash Attention override set by flag: {use_flash_attention}")
    env.update(
        {
            "USE_FLASH_ATTENTION": use_flash_attention,
            "USE_MEM_EFF_ATTENTION": use_flash_attention,
        }
    )
    print(
        f"Flash Attention and Memory efficiency enabled: {env['USE_FLASH_ATTENTION'] == 'ON'}"
    )

env["USE_ROCM"] = "ON"
env["USE_CUDA"] = "OFF"
env["USE_MPI"] = "OFF"
env["USE_NUMA"] = "OFF"
env["PYTORCH_BUILD_VERSION"] = pytorch_build_version
env["PYTORCH_BUILD_NUMBER"] = args.pytorch_build_number

# Determine which install requirements to add.
install_requirements = [
    f"rocm[libraries]=={get_rocm_sdk_version()}",
]
if triton_requirement:
    install_requirements.append(triton_requirement)
env["PYTORCH_EXTRA_INSTALL_REQUIREMENTS"] = "|".join(install_requirements)
print(
    f"--- PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {env['PYTORCH_EXTRA_INSTALL_REQUIREMENTS']}"
)

# Add the _rocm_init.py file.
(pytorch_dir / "torch" / "_rocm_init.py").write_text(get_rocm_init_contents(args))

# Windows-specific settings.
if is_windows:
    copy_msvc_libomp_to_torch_lib(pytorch_dir)
    use_flash_attention = "0"
    if args.enable_pytorch_flash_attention_windows and not any(
        arch in env["PYTORCH_ROCM_ARCH"] for arch in AOTRITON_UNSUPPORTED_ARCHS
    ):
        use_flash_attention = "1"
    env.update(
        {
            "USE_FLASH_ATTENTION": use_flash_attention,
            "USE_MEM_EFF_ATTENTION": use_flash_attention,
            "DISTUTILS_USE_SDK": "1",
            # Workaround compile errors in 'aten/src/ATen/test/hip/hip_vectorized_test.hip'
            # on Torch 2.7.0: https://gist.github.com/ScottTodd/befdaf6c02a8af561f5ac1a2bc9c7a76.
            #   error: no member named 'modern' in namespace 'at::native'
            #     using namespace at::native::modern::detail;
            #   error: no template named 'has_same_arg_types'
            #     static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
            # We may want to fix that and other issues to then enable building tests.
            "BUILD_TEST": "0",
        }
    )
    print(
        f"  Flash attention enabled: {args.enable_pytorch_flash_attention_windows or not is_windows}"
    )
```
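To see why the multi-arch run above ended up with USE_FLASH_ATTENTION=0, the default (no-override) decision reduces to this self-contained sketch (arch list mirrors the snippet above; the target strings in the examples are illustrative):

```python
AOTRITON_UNSUPPORTED_ARCHS = ["gfx900", "gfx906", "gfx908", "gfx101", "gfx103"]

def default_flash_attention(pytorch_rocm_arch: str, triton_requirement: bool) -> str:
    """Mirrors the default-path decision in the build script."""
    use_flash_attention = "ON" if triton_requirement else "OFF"
    # Substring check: one unsupported family anywhere in the target list
    # disables flash attention for *every* target in the build.
    if any(arch in pytorch_rocm_arch for arch in AOTRITON_UNSUPPORTED_ARCHS):
        use_flash_attention = "OFF"
    return use_flash_attention

# Single-family build for a supported target: stays ON.
print(default_flash_attention("gfx1151", triton_requirement=True))  # ON
# Multi-arch build that also includes gfx906: everything goes OFF.
print(default_flash_attention("gfx1151;gfx1100;gfx906", triton_requirement=True))  # OFF
```

This is exactly the situation the multi-arch release hits: the combined target list always contains at least one aotriton-unsupported family.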

Notes:

  • If any arch in PYTORCH_ROCM_ARCH is in AOTRITON_UNSUPPORTED_ARCHS, flash attention and memory-efficient attention are disabled for the entire build.
    • We should check whether conditionally enabling the features, or filtering unsupported archs out later in the build, is possible. This all-or-nothing filtering is particularly disruptive in the new style of building multi-arch packages.
  • The logging is wrong on Windows: the final print reads args.enable_pytorch_flash_attention_windows and NOT the computed use_flash_attention, so it can report flash attention as enabled even when the arch check forced it off.
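One possible shape for the first note, as a hypothetical sketch only (this is not the current build logic, and filter_aotriton_archs is an invented helper name): drop the unsupported families from the arch list handed to aotriton and keep flash attention enabled for whatever remains, instead of disabling it globally.

```python
AOTRITON_UNSUPPORTED_ARCHS = ["gfx900", "gfx906", "gfx908", "gfx101", "gfx103"]

def filter_aotriton_archs(pytorch_rocm_arch: str) -> tuple[str, str]:
    """Hypothetical alternative: keep flash attention ON for the supported
    subset of a multi-arch target list rather than all-or-nothing."""
    targets = pytorch_rocm_arch.split(";")
    supported = [
        t for t in targets
        if not any(t.startswith(prefix) for prefix in AOTRITON_UNSUPPORTED_ARCHS)
    ]
    # Only disable entirely if no target survives the filter.
    use_flash_attention = "ON" if supported else "OFF"
    return ";".join(supported), use_flash_attention

# gfx906 is dropped from the aotriton arch list; the rest keeps flash attention.
print(filter_aotriton_archs("gfx1151;gfx1100;gfx906"))
```

Whether aotriton and the PyTorch build can actually consume a narrower arch list than PYTORCH_ROCM_ARCH is the open question the note raises; this sketch only illustrates the filtering side.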

Other implications/symptoms

aotriton builds both target-specific and family-specific kernels, and the kpack splitting step now produces separate device packages for each of them.

Our dev release index at https://rocm.devreleases.amd.com/whl-staging-multi-arch/ includes builds for different combinations of targets depending on what we're testing, while our nightly release index always builds for all targets. The family-specific packages (produced when aotriton is enabled) appear only in the devreleases index.

cc @xinyazhang , @HereThereBeDragons , @araravik-psd , @marbre
