### Describe the bug
With --enable_xformers_memory_efficient_attention enabled, the SDXL DreamBooth LoRA training script (train_dreambooth_lora_sdxl.py) crashes during the UNet forward pass.
### Reproduction
Run examples/dreambooth/train_dreambooth_lora_sdxl.py with the following flags (the traceback also shows gradient-checkpointing frames, so --gradient_checkpointing was presumably set as well):

`--use_8bit_adam --push_to_hub --enable_xformers_memory_efficient_attention`
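For context, --enable_xformers_memory_efficient_attention makes the script call unet.enable_xformers_memory_efficient_attention() on the diffusers UNet. Below is a minimal model-level sketch of the conditions the traceback shows (xFormers attention, gradient checkpointing, an fp16 autocast forward); the model id, shapes, and dummy inputs are illustrative assumptions, not taken from the report:

```python
import torch
from diffusers import UNet2DConditionModel

# Sketch of the failing call path: xFormers attention + gradient
# checkpointing + an fp16 autocast forward, as in the traceback.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
).to("cuda")
unet.enable_xformers_memory_efficient_attention()
unet.enable_gradient_checkpointing()
unet.train()

# Dummy SDXL inputs: 128x128 latents (1024px images), 77x2048 text
# states, 1280-d pooled embeds, 6 micro-conditioning time ids.
sample = torch.randn(1, 4, 128, 128, device="cuda")
timestep = torch.tensor([10], device="cuda")
encoder_hidden_states = torch.randn(1, 77, 2048, device="cuda")
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280, device="cuda"),
    "time_ids": torch.randn(1, 6, device="cuda"),
}

with torch.autocast("cuda", dtype=torch.float16):
    model_pred = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs,
    ).sample
```

Whether this standalone forward trips the same ValueError may depend on which tensors stay in float32 under autocast (the training script keeps trainable LoRA parameters in float32 under mixed precision); it mirrors the call path rather than guaranteeing the crash.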
### Logs
```
Traceback (most recent call last):
  File "E:\diffusers\examples\dreambooth\train_dreambooth_lora_sdxl.py", line 2021, in <module>
    main(args)
  File "E:\diffusers\examples\dreambooth\train_dreambooth_lora_sdxl.py", line 1731, in main
    model_pred = unet(
                 ^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\accelerate\utils\operations.py", line 814, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\accelerate\utils\operations.py", line 802, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\amp\autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\unets\unet_2d_condition.py", line 1214, in forward
    sample, res_samples = downsample_block(
                          ^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\unets\unet_2d_blocks.py", line 1260, in forward
    hidden_states = attn(
                    ^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\transformers\transformer_2d.py", line 416, in forward
    hidden_states = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\modeling_utils.py", line 320, in _gradient_checkpointing_func
    return torch.utils.checkpoint.checkpoint(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\_dynamo\eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\utils\checkpoint.py", line 495, in checkpoint
    ret = function(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\attention.py", line 552, in forward
    attn_output = self.attn2(
                  ^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\attention_processor.py", line 605, in forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\attention_processor.py", line 3106, in __call__
    hidden_states = xformers.ops.memory_efficient_attention(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 306, in memory_efficient_attention
    return _memory_efficient_attention(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 475, in _memory_efficient_attention
    return _fMHA.apply(
           ^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\autograd\function.py", line 575, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 89, in forward
    out, op_ctx = _memory_efficient_attention_forward_requires_grad(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 497, in _memory_efficient_attention_forward_requires_grad
    inp.validate_inputs()
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\common.py", line 240, in validate_inputs
    raise ValueError(
```
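The log is truncated before the ValueError message is printed. The failing check is validate_inputs() in xformers/ops/fmha/common.py, which among other things requires query, key, and value to share a single dtype; with autocast and checkpoint recomputation in the stack, a dtype mismatch is a plausible trigger, though the truncated log does not confirm it. A minimal sketch of that failure mode, assuming the mismatch cause:

```python
import torch
import xformers.ops

# Hypothetical trigger for the crash above: xFormers rejects attention
# inputs whose dtypes disagree, e.g. a float32 query (from LoRA weights
# kept in fp32) against float16 key/value tensors.
q = torch.randn(2, 64, 10, 64, device="cuda", dtype=torch.float32)
k = torch.randn(2, 64, 10, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 64, 10, 64, device="cuda", dtype=torch.float16)

# Raises ValueError from validate_inputs() because of the mixed dtypes.
out = xformers.ops.memory_efficient_attention(q, k, v)
```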
### System Info
- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Windows-10-10.0.26100-SP0
- Running on Google Colab?: No
- Python version: 3.11.9
- PyTorch version (GPU?): 2.7.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.51.3
- Accelerate version: 1.6.0
- PEFT version: 0.7.0
- Bitsandbytes version: 0.45.5
- Safetensors version: 0.5.3
- xFormers version: 0.0.30
### Who can help?
No response