Reverse the sdpa hijack order
Disty0 committed Feb 14, 2025
1 parent 342a757 commit 878cab0
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions modules/devices.py
@@ -426,24 +426,18 @@ def set_sdpa_params():
         except Exception as err:
             log.warning(f'Torch attention: type="sdpa" {err}')

-        # stack hijacks with order: Sage -> CK Flash -> Dynamic
-        # if the first is not compatible, uses the second and so on
+        # Stack hijacks in reverse order. This gives priority to the last added hijack.
+        # If the last hijack is not compatible, it will use the one before it and so on.

-        if 'Sage attention' in opts.sdp_options:
+        if 'Dynamic attention' in opts.sdp_options:
             try:
-                install('sageattention')
-                from sageattention import sageattn
-                sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
-                @wraps(sdpa_pre_sage_atten)
-                def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-                    if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32):
-                        return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-                    else:
-                        return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
-                torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
-                log.debug('Torch attention: type="sage attention"')
+                global sdpa_pre_dyanmic_atten # pylint: disable=global-statement
+                sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
+                from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
+                torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
+                log.debug('Torch attention: type="dynamic attention"')
             except Exception as err:
-                log.error(f'Torch attention: type="sage attention" {err}')
+                log.error(f'Torch attention: type="dynamic attention" {err}')

         if 'CK Flash attention' in opts.sdp_options:
             try:
@@ -466,15 +460,21 @@ def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal
             except Exception as err:
                 log.error(f'Torch attention: type="ck flash attention" {err}')

-        if 'Dynamic attention' in opts.sdp_options:
+        if 'Sage attention' in opts.sdp_options:
             try:
-                global sdpa_pre_dyanmic_atten # pylint: disable=global-statement
-                sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
-                from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
-                torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
-                log.debug('Torch attention: type="dynamic attention"')
+                install('sageattention')
+                from sageattention import sageattn
+                sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
+                @wraps(sdpa_pre_sage_atten)
+                def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+                    if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32):
+                        return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                    else:
+                        return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+                torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
+                log.debug('Torch attention: type="sage attention"')
             except Exception as err:
-                log.error(f'Torch attention: type="dynamic attention" {err}')
+                log.error(f'Torch attention: type="sage attention" {err}')
     except Exception as e:
         log.warning(f'Torch SDPA: {e}')
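
For readers skimming the change: the new comment describes a fallback chain built by patching torch.nn.functional.scaled_dot_product_attention several times in a row. Each hijack remembers whatever implementation was installed before it and falls back to that when its own preconditions are not met, so applying the hijacks in reverse priority order (Dynamic, then CK Flash, then Sage) means the hijack installed last is tried first at call time. Below is a minimal, self-contained sketch of that pattern; it is not code from the repository, and the names add_hijack, base_attention and the toy predicates are hypothetical stand-ins.

from functools import wraps

def base_attention(query):
    # stand-in for the original torch.nn.functional.scaled_dot_product_attention
    return f"base({query})"

# module-level slot playing the role of the patched torch function
current_attention = base_attention

def add_hijack(name, is_supported):
    # wrap whatever implementation is currently installed with a conditional override
    global current_attention
    previous = current_attention  # remember the implementation installed before this hijack

    @wraps(previous)
    def wrapper(query):
        if is_supported(query):
            return f"{name}({query})"
        return previous(query)  # not compatible: fall back to the earlier hijack (or the base)

    current_attention = wrapper

# installed in reverse priority order, mirroring the commit: Dynamic, CK Flash, then Sage
add_hijack("dynamic", lambda q: True)           # always applicable
add_hijack("ck_flash", lambda q: len(q) <= 8)   # toy "compatibility" check
add_hijack("sage", lambda q: len(q) % 2 == 0)   # toy "compatibility" check

print(current_attention("even"))         # sage("even")           -> Sage handles it
print(current_attention("queries"))      # ck_flash("queries")    -> Sage refused, CK Flash took it
print(current_attention("long & odd!"))  # dynamic("long & odd!") -> only Dynamic accepted it

In the real module the wrappers delegate based on head size, attention mask and dtype, as the diff above shows.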
