Commit 7f0c95b

add flash attention support using triton
1 parent 9bbf28e commit 7f0c95b

2 files changed: +63 −33 lines

modules/rocm_triton_windows.py

Lines changed: 6 additions & 31 deletions

@@ -1,7 +1,9 @@
 import sys
-from functools import wraps
 import torch
-from modules import shared, devices
+from modules import devices
+
+
+is_available = False


 if sys.platform == "win32":
@@ -94,34 +96,7 @@ def triton_runtime_driver_active_utils_get_device_properties(device):
             return props
         triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties

-        if 'Flash attention' in shared.opts.sdp_options:
-            from modules.flash_attn_triton_amd import interface_fa
-            sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
-            @wraps(sdpa_pre_flash_atten)
-            def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
-                    if scale is None:
-                        scale = query.shape[-1] ** (-0.5)
-                    head_size_og = query.size(3)
-                    if head_size_og % 8 != 0:
-                        query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
-                        key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
-                        value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
-                    query = query.transpose(1, 2)
-                    out_padded = torch.zeros_like(query)
-                    interface_fa.fwd(
-                        query,
-                        key.transpose(1, 2),
-                        value.transpose(1, 2),
-                        out_padded,
-                        dropout_p,
-                        scale,
-                        is_causal,
-                    )
-                    return out_padded[..., :head_size_og].transpose(1, 2)
-                else:
-                    return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-            torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
-            print('Torch attention: type="triton flash attention"')
+        global is_available
+        is_available = True
     except Exception:
         pass
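With this change, rocm_triton_windows.py only probes the Triton runtime on Windows/ZLUDA and records the outcome in a module-level is_available flag; the SDPA monkey-patch itself moves into sd_hijack_optimizations.py below. A minimal sketch of that probe-and-flag pattern, with an illustrative module name and a bare import standing in for the real driver patching done in this file:

# optional_triton_probe.py -- illustrative stand-in, not part of this commit
import sys

is_available = False        # importers check this flag instead of re-probing

if sys.platform == "win32":
    try:
        import triton                  # optional dependency; may be absent or broken
        # ... patch/probe whatever the backend needs here ...
        is_available = True            # reached only if every step above succeeded
    except Exception:
        pass                           # any failure leaves the flag False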

modules/sd_hijack_optimizations.py

Lines changed: 57 additions & 2 deletions

@@ -2,14 +2,15 @@
 import math
 import psutil
 import platform
+from functools import wraps

 import torch
 from torch import einsum

 from ldm.util import default
 from einops import rearrange

-from modules import shared, errors, devices, sub_quadratic_attention
+from modules import shared, errors, devices, sub_quadratic_attention, rocm_triton_windows
 from modules.hypernetworks import hypernetwork

 import ldm.modules.attention
@@ -34,7 +35,7 @@ def title(self):

         return f"{self.name} - {self.label}"

-    def is_available(self):
+    def is_available(self) -> bool:
         return True

     def apply(self):
@@ -143,8 +144,62 @@ def apply(self):
         sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


+class SdOptimizationTritonFlashAttention(SdOptimization):
+    name = "Flash attention"
+    cmd_opt = "flash_attn"
+    priority = 100
+
+    def __init__(self):
+        super().__init__()
+        self.sdpa_pre_flash_atten = None
+
+    def is_available(self):
+        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) and devices.has_zluda() and rocm_triton_windows.is_available
+
+    def apply(self):
+        if self.sdpa_pre_flash_atten is None:
+            from modules.flash_attn_triton_amd import interface_fa
+            self.sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
+            @wraps(self.sdpa_pre_flash_atten)
+            def sdpa_flash_atten(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
+                    if scale is None:
+                        scale = query.shape[-1] ** (-0.5)
+                    head_size_og = query.size(3)
+                    if head_size_og % 8 != 0:
+                        query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
+                        key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
+                        value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
+                    query = query.transpose(1, 2)
+                    out_padded = torch.zeros_like(query)
+                    interface_fa.fwd(
+                        query,
+                        key.transpose(1, 2),
+                        value.transpose(1, 2),
+                        out_padded,
+                        dropout_p,
+                        scale,
+                        is_causal,
+                    )
+                    return out_padded[..., :head_size_og].transpose(1, 2)
+                else:
+                    return self.sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+            torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
+
+        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+    def undo(self):
+        super().undo()
+        torch.nn.functional.scaled_dot_product_attention = self.sdpa_pre_flash_atten
+        self.sdpa_pre_flash_atten = None
+
+
 def list_optimizers(res):
     res.extend([
+        SdOptimizationTritonFlashAttention(),
         SdOptimizationXformers(),
         SdOptimizationSdpNoMem(),
         SdOptimizationSdp(),