enable_xformers_memory_efficient_attention in the training script #11486

Open
@master861

Description

Describe the bug

Running the SDXL DreamBooth LoRA training script with --enable_xformers_memory_efficient_attention causes it to crash.

Reproduction

--use_8bit_adam --push_to_hub --enable_xformers_memory_efficient_attention
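For context, a fuller invocation along these lines reproduces the crash. Only the three flags above come from the report; the model path, data directory, prompt, and remaining flags are placeholders sketched from the script's standard arguments (gradient checkpointing and autocast both appear in the traceback, so --gradient_checkpointing and --mixed_precision="fp16" are assumed to have been enabled):

```shell
# Hypothetical full command; paths, prompt, and most flags are placeholders,
# not the exact command used by the reporter.
accelerate launch examples/dreambooth/train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --instance_data_dir="./instance_images" \
  --output_dir="./lora-output" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --push_to_hub \
  --enable_xformers_memory_efficient_attention
```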

Logs

Traceback (most recent call last):
  File "E:\diffusers\examples\dreambooth\train_dreambooth_lora_sdxl.py", line 2021, in <module>
    main(args)
  File "E:\diffusers\examples\dreambooth\train_dreambooth_lora_sdxl.py", line 1731, in main
    model_pred = unet(
                 ^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\accelerate\utils\operations.py", line 814, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\accelerate\utils\operations.py", line 802, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\amp\autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\unets\unet_2d_condition.py", line 1214, in forward
    sample, res_samples = downsample_block(
                          ^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\unets\unet_2d_blocks.py", line 1260, in forward
    hidden_states = attn(
                    ^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\transformers\transformer_2d.py", line 416, in forward
    hidden_states = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\modeling_utils.py", line 320, in _gradient_checkpointing_func
    return torch.utils.checkpoint.checkpoint(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\_dynamo\eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\utils\checkpoint.py", line 495, in checkpoint
    ret = function(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\attention.py", line 552, in forward
    attn_output = self.attn2(
                  ^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\attention_processor.py", line 605, in forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\diffusers\models\attention_processor.py", line 3106, in __call__
    hidden_states = xformers.ops.memory_efficient_attention(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 306, in memory_efficient_attention
    return _memory_efficient_attention(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 475, in _memory_efficient_attention
    return _fMHA.apply(
           ^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\torch\autograd\function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 89, in forward
    out, op_ctx = _memory_efficient_attention_forward_requires_grad(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\__init__.py", line 497, in _memory_efficient_attention_forward_requires_grad
    inp.validate_inputs()
  File "D:\Python311\Lib\site-packages\xformers\ops\fmha\common.py", line 240, in validate_inputs
    raise ValueError(

System Info

  • 🤗 Diffusers version: 0.34.0.dev0
  • Platform: Windows-10-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.11.9
  • PyTorch version (GPU?): 2.7.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.30.2
  • Transformers version: 4.51.3
  • Accelerate version: 1.6.0
  • PEFT version: 0.7.0
  • Bitsandbytes version: 0.45.5
  • Safetensors version: 0.5.3
  • xFormers version: 0.0.30

Who can help?

No response

    Labels

    bug (Something isn't working)
