 import math
 import psutil
 import platform
+from functools import wraps
 
 import torch
 from torch import einsum
 
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared, errors, devices, sub_quadratic_attention
+from modules import shared, errors, devices, sub_quadratic_attention, rocm_triton_windows
 from modules.hypernetworks import hypernetwork
 
 import ldm.modules.attention
@@ -34,7 +35,7 @@ def title(self):
 
         return f"{self.name} - {self.label}"
 
-    def is_available(self):
+    def is_available(self) -> bool:
         return True
 
     def apply(self):
@@ -143,8 +144,62 @@ def apply(self):
         sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
 
 
+class SdOptimizationTritonFlashAttention(SdOptimization):
+    name = "Flash attention"
+    cmd_opt = "flash_attn"
+    priority = 100
+
+    def __init__(self):
+        super().__init__()
+        self.sdpa_pre_flash_atten = None
+
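+    # Available only when PyTorch exposes a callable scaled_dot_product_attention and we are running under ZLUDA with the Windows ROCm Triton bridge present.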
+    def is_available(self):
+        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) and devices.has_zluda() and rocm_triton_windows.is_available
+
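+    # Monkey-patch torch SDPA once: eligible calls go to the Triton flash-attention kernel, everything else falls back to the saved original.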
+    def apply(self):
+        if self.sdpa_pre_flash_atten is None:
+            from modules.flash_attn_triton_amd import interface_fa
+            self.sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
+            @wraps(self.sdpa_pre_flash_atten)
+            def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
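+                # Only route to the Triton kernel for head dims up to 128, no attention mask, and non-fp32 inputs.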
+                if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
+                    if scale is None:
+                        scale = query.shape[-1] ** (-0.5)
+                    head_size_og = query.size(3)
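+                    # Pad the head dimension up to a multiple of 8 before handing the tensors to the kernel.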
+                    if head_size_og % 8 != 0:
+                        query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
+                        key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
+                        value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
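+                    # SDPA callers pass (batch, heads, seq, head_dim); transpose to the (batch, seq, heads, head_dim) layout the flash-attention interface consumes.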
+                    query = query.transpose(1, 2)
+                    out_padded = torch.zeros_like(query)
+                    interface_fa.fwd(
+                        query,
+                        key.transpose(1, 2),
+                        value.transpose(1, 2),
+                        out_padded,
+                        dropout_p,
+                        scale,
+                        is_causal,
+                    )
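+                    # Strip the padding and restore the (batch, heads, seq, head_dim) layout.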
+                    return out_padded[..., :head_size_og].transpose(1, 2)
+                else:
+                    return self.sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+            torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
+
+        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
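+    # undo() puts the saved torch SDPA back so other optimizers see the stock implementation.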
+    def undo(self):
+        super().undo()
+        torch.nn.functional.scaled_dot_product_attention = self.sdpa_pre_flash_atten
+        self.sdpa_pre_flash_atten = None
+
+
 def list_optimizers(res):
     res.extend([
+        SdOptimizationTritonFlashAttention(),
         SdOptimizationXformers(),
         SdOptimizationSdpNoMem(),
         SdOptimizationSdp(),