Summary
xFormers has an import conflict in xformers/ops/fmha/flash3.py: when the global flash_attn_3 package is imported before xFormers, attempting to import xformers.ops.fmha.flash3 fails due to a torch extension namespace conflict.
Bug Description
The xFormers library attempts to support both a vendored (bundled) version of Flash Attention 3 and the global pip-installed flash_attn_3 package. However, the import logic in xformers/ops/fmha/flash3.py prioritizes the vendored version first, which causes a conflict when the global package has already been imported and registered its torch operations.
Both packages attempt to register torch operations under the same namespace torch.ops.flash_attn_3, leading to registration conflicts and import failures.
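The constraint behind this can be reproduced in isolation with PyTorch's operator-library API, independent of either package. The sketch below uses an arbitrary namespace name ("demo_collision") purely to illustrate the "already registered" behavior; the exact error text may differ from the C++-extension case in the traceback further down.
import torch

# Only one library may *define* a given operator namespace in a process.
lib_a = torch.library.Library("demo_collision", "DEF")  # first definition succeeds
lib_a.define("noop(Tensor x) -> Tensor")

try:
    lib_b = torch.library.Library("demo_collision", "DEF")  # second definition of the same namespace
except RuntimeError as e:
    print("conflict:", e)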
The conflict is bidirectional:
- If flash_attn_3 is imported first, then importing xFormers fails
- If xFormers is imported first, then importing flash_attn_3 fails (see the sketch below)
Real-world example: This issue occurs when importing diffusers.models.autoencoders, which internally imports flash_attn_3 (if available) and xFormers. The import failure prevents using both libraries together.
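For completeness, the reverse order described above looks like this (a sketch based on the report; not re-verified here):
import xformers.ops.fmha.flash3  # loads the vendored extension, registering torch.ops.flash_attn_3
import flash_attn_3              # per the report, this then fails when its own _C tries to register the same namespace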
Steps to Reproduce
- Ensure both flash_attn_3 and xformers are installed
- Import flash_attn_3 first, which registers torch.ops.flash_attn_3
- Then attempt to import xformers.ops.fmha.flash3
- Import conflicts occur due to the namespace collision
Minimal Reproduction
import flash_attn_3 # Registers torch.ops.flash_attn_3
import xformers.ops.fmha.flash3 # Fails with RuntimeError
Expected Behavior
- Importing xformers.ops.fmha.flash3 should work regardless of whether flash_attn_3 has been imported first
- The library should detect when the global package is already available and use it without conflicts
Actual Behavior
- When flash_attn_3 is imported first, the subsequent import of xformers.ops.fmha.flash3 fails
- The current import order attempts to import the vendored version first (lines 107-113), which conflicts with the already-registered global package
- The global package import is only attempted if the vendored version is not present (lines 115-127)
Root Cause
In xformers/ops/fmha/flash3.py (lines 106-127):
_C_flashattention3 = None
if importlib.util.find_spec("...flash_attn_3._C", package=__package__):
    # Vendored version - tried first
    from ..._cpp_lib import _build_metadata
    from ...flash_attn_3 import _C  # Attempts to register torch.ops.flash_attn_3
    _C_flashattention3 = torch.ops.flash_attn_3
elif importlib.util.find_spec("flash_attn_3") and importlib.util.find_spec("flash_attn_3._C"):
    # Global version - only tried if the vendored version is not present
    import flash_attn_3._C  # Conflicts if already imported
    _C_flashattention3 = torch.ops.flash_attn_3

The problem:
- If the global flash_attn_3 is imported first, it registers torch.ops.flash_attn_3
- When xFormers tries to import its vendored version, it attempts to re-register the same namespace
- This causes torch extension conflicts
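One possible mitigation (a sketch only, not xFormers' actual code; the exact checks and ordering are assumptions) is to reorder the branches so that an already-imported or importable global flash_attn_3 is reused before the vendored copy is loaded. Written as if it lived in xformers/ops/fmha/flash3.py, mirroring the block quoted above:
import importlib.util
import sys

import torch

_C_flashattention3 = None
if "flash_attn_3" in sys.modules or (
    importlib.util.find_spec("flash_attn_3")
    and importlib.util.find_spec("flash_attn_3._C")
):
    # Global package already imported (or importable): reuse its registration
    import flash_attn_3._C  # noqa: F401
    _C_flashattention3 = torch.ops.flash_attn_3
elif importlib.util.find_spec("...flash_attn_3._C", package=__package__):
    # Vendored copy is used only when the global package is absent
    from ...flash_attn_3 import _C  # noqa: F401
    _C_flashattention3 = torch.ops.flash_attn_3
An alternative would be to wrap the vendored import in a try/except RuntimeError and fall back to the already-registered torch.ops.flash_attn_3 namespace when the registration fails.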
Actual Error Traceback
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.11/site-packages/xformers/ops/fmha/flash3.py", line 111, in <module>
from ...flash_attn_3 import _C
File "/usr/local/lib/python3.11/site-packages/xformers/flash_attn_3/__init__.py", line 5, in <module>
from . import _C
RuntimeError: Library "flash_attn_3" already registered! This usually means that
you are trying to load two different versions of the same library, or that the
library has already been loaded. The torch.ops registry cannot contain duplicate
entries for the same operator namespace.
Environment
PyTorch Environment Details
PyTorch version: 2.7.1+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Python version: 3.11.10 (64-bit runtime)
Python platform: Linux-6.8.0-1040-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA L40S
Nvidia driver version: 570.172.08
cuDNN version: 9.14.0
Versions of relevant libraries:
torch==2.7.1
torchvision==0.22.1
torchaudio==2.7.1
triton==3.3.1
Note: Flash Attention 3 is not shipped as a released package, but can be built following these instructions.