Flash Attention 3 torch extension namespace conflict #1348

@georgkaleido

Description

Summary

xFormers has an import conflict in xformers/ops/fmha/flash3.py: when the global flash_attn_3 package is imported before xFormers, importing xformers.ops.fmha.flash3 fails due to a torch extension namespace conflict.

Bug Description

The xFormers library attempts to support both a vendored (bundled) version of Flash Attention 3 and the global pip-installed flash_attn_3 package. However, the import logic in xformers/ops/fmha/flash3.py tries the vendored version first, which causes a conflict when the global package has already been imported and has registered its torch operations.

Both packages attempt to register torch operations under the same namespace torch.ops.flash_attn_3, leading to registration conflicts and import failures.
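
For illustration, the failure mode can be reproduced without either package: PyTorch allows only one owner per operator-library namespace, so a second attempt to define the same namespace is rejected. The sketch below uses the Python-level torch.library API to show this; the namespace name flash_attn_3 merely stands in for what the two _C extensions register, and the exact error text may differ from the C++ registration path.

import torch

# The first definition claims the "flash_attn_3" operator namespace,
# just as the first extension (.so) that gets loaded does.
lib_a = torch.library.Library("flash_attn_3", "DEF")
lib_a.define("dummy(Tensor x) -> Tensor")

# A second attempt to own the same namespace is rejected; this is the
# same class of error hit when both the vendored and the global
# extension try to register torch.ops.flash_attn_3.
try:
    lib_b = torch.library.Library("flash_attn_3", "DEF")
except RuntimeError as err:
    print(err)  # e.g. "Only a single TORCH_LIBRARY can be used to register the namespace flash_attn_3"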

The conflict is bidirectional:

  • If flash_attn_3 is imported first, then importing xFormers fails
  • If xFormers is imported first, then importing flash_attn_3 fails

Real-world example: this issue occurs when importing diffusers.models.autoencoders, which internally imports both flash_attn_3 (if available) and xFormers. The import failure prevents using the two libraries together.

Steps to Reproduce

  1. Ensure both flash_attn_3 and xformers are installed
  2. Import flash_attn_3 first, which registers torch.ops.flash_attn_3
  3. Then attempt to import xformers.ops.fmha.flash3
  4. Import conflicts occur due to the namespace collision

Minimal Reproduction

import flash_attn_3  # Registers torch.ops.flash_attn_3
import xformers.ops.fmha.flash3  # Fails with RuntimeError

Expected Behavior

  • Importing xformers.ops.fmha.flash3 should work regardless of whether flash_attn_3 has been imported first
  • The library should detect when the global package is already available and use it without conflicts

Actual Behavior

  • When flash_attn_3 is imported first, subsequent import of xformers.ops.fmha.flash3 fails
  • The current import order attempts to import the vendored version first (lines 107-113), which conflicts with the already-registered global package
  • The global package import is only attempted if the vendored version is not present (lines 115-127)

Root Cause

In xformers/ops/fmha/flash3.py (lines 106-127):

_C_flashattention3 = None
if importlib.util.find_spec("...flash_attn_3._C", package=__package__):
    # Vendored version - tries to import first
    from ..._cpp_lib import _build_metadata
    from ...flash_attn_3 import _C  # Attempts to register torch.ops.flash_attn_3
    _C_flashattention3 = torch.ops.flash_attn_3

elif importlib.util.find_spec("flash_attn_3") and importlib.util.find_spec("flash_attn_3._C"):
    # Global version - only tried if vendored is not present
    import flash_attn_3._C  # Conflicts if already imported
    _C_flashattention3 = torch.ops.flash_attn_3

The problem:

  1. If the global flash_attn_3 is imported first, it registers torch.ops.flash_attn_3
  2. When xFormers tries to import its vendored version, it attempts to re-register the same namespace
  3. This causes torch extension conflicts
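
One possible workaround (a sketch only, not an official fix) is to reorder the checks in xformers/ops/fmha/flash3.py so that an already-imported global flash_attn_3 is reused rather than shadowed by the vendored extension. Under that assumption, the import block could look roughly like this:

import importlib.util
import sys

import torch

_C_flashattention3 = None

if "flash_attn_3" in sys.modules:
    # The global package has already registered torch.ops.flash_attn_3;
    # reuse its ops instead of re-registering via the vendored extension.
    _C_flashattention3 = torch.ops.flash_attn_3
elif importlib.util.find_spec("...flash_attn_3._C", package=__package__):
    # Vendored copy, loaded only when nothing has claimed the namespace yet.
    from ...flash_attn_3 import _C  # noqa: F401
    _C_flashattention3 = torch.ops.flash_attn_3
elif importlib.util.find_spec("flash_attn_3") and importlib.util.find_spec("flash_attn_3._C"):
    # Global copy is installed but not imported yet; importing it here
    # registers the namespace exactly once.
    import flash_attn_3._C  # noqa: F401
    _C_flashattention3 = torch.ops.flash_attn_3

Note that this only addresses the order described above; the reverse order (xFormers imported first, then flash_attn_3) would still conflict and would need an equivalent guard on the other side.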

Actual Error Traceback

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.11/site-packages/xformers/ops/fmha/flash3.py", line 111, in <module>
    from ...flash_attn_3 import _C
  File "/usr/local/lib/python3.11/site-packages/xformers/flash_attn_3/__init__.py", line 5, in <module>
    from . import _C
RuntimeError: Library "flash_attn_3" already registered! This usually means that
you are trying to load two different versions of the same library, or that the
library has already been loaded. The torch.ops registry cannot contain duplicate
entries for the same operator namespace.

Environment

PyTorch Environment Details
PyTorch version: 2.7.1+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Python version: 3.11.10 (64-bit runtime)
Python platform: Linux-6.8.0-1040-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA L40S
Nvidia driver version: 570.172.08
cuDNN version: 9.14.0

Versions of relevant libraries:
torch==2.7.1
torchvision==0.22.1
torchaudio==2.7.1
triton==3.3.1

Note: Flash Attention 3 is not shipped as a released package, but it can be built by following these instructions.
