Added backend context manager to select SDPA implementation #1061
base: main
Changes from all commits
```diff
@@ -22,12 +22,14 @@
 import contextlib
 import importlib
 import math
-from typing import Any, Dict, List, Set
+from contextvars import ContextVar
+from typing import Any, Dict, Generator, List, Literal, Set
 
 import numpy as np
 import nvtx
 import torch
 from einops import rearrange
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh
 
 from physicsnemo.models.diffusion import weight_init
```
```diff
@@ -656,6 +658,14 @@ class Attention(torch.nn.Module):
     -------
     torch.Tensor
         Output tensor of the same shape as input: :math:`(B, C, H, W)`.
+
+    .. important::
+
+        By default, this implementation uses scaled dot product attention (SDPA) from
+        `torch.nn.functional.scaled_dot_product_attention <https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_.
+        This operator is optimized for performance, but it is still in beta and its results may be affected by numerical errors.
+        To use a pure Python implementation instead, set the SDPA backend to "python" with the
+        :meth:`SDPA_backend` context manager.
     """
 
     def __init__(
```
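As a usage illustration of the note above, a sketch only: it assumes the `SDPA_backend` context manager added later in this diff, the constructor arguments from its docstring example, and that the listed backends are available in the local PyTorch build. Besides "python", the same context manager can restrict SDPA to a subset of PyTorch's own backends:

```python
import torch
from torch.nn.attention import SDPBackend
from physicsnemo.models.diffusion.layers import Attention

model = Attention(out_channels=16, num_heads=2)
x = torch.randn(1, 16, 8, 8)

# Allow only the listed SDPA backends inside the block; outside the block,
# the default backend selection applies again.
with Attention.SDPA_backend([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    y = model(x)
```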
```diff
@@ -682,6 +692,7 @@ def __init__(
                 f"`out_channels` must be divisible by `num_heads`, but got {out_channels} and {num_heads}"
             )
         self.num_heads = num_heads
+        self._sdpa_backend = None
         self.norm = get_group_norm(
             num_channels=out_channels,
             eps=eps,
```
```diff
@@ -705,29 +716,94 @@ def __init__(
             **init_zero,
         )
 
+    # Variables for selecting the SDPA backend with a context manager
+    _SDPA_BACKEND: ContextVar[Any] = ContextVar("sdpa_backend", default=None)
+
+    @classmethod
+    def _get_sdpa_backend(
+        cls,
+    ) -> None | SDPBackend | List[SDPBackend] | Literal["python"]:
+        return cls._SDPA_BACKEND.get()
+
+    @staticmethod
+    @contextlib.contextmanager
+    def SDPA_backend(
+        backend: None | SDPBackend | List[SDPBackend] | Literal["python"],
+    ) -> Generator[None, None, None]:
+        """
+        Context manager to select the SDPA backend.
+
+        Parameters
+        ----------
+        backend : None | SDPBackend | List[SDPBackend] | Literal["python"]
+            - If ``None``, the default implementation based on
+              ``torch.nn.functional.scaled_dot_product_attention`` is used without
+              any specific backend.
+            - If ``"python"``, a custom Python implementation of attention
+              based on :class:`~physicsnemo.models.diffusion.layers.AttentionOp`
+              is used. This backend is less performant but less sensitive to
+              numerical errors.
+            - In all other cases, the ``backend`` parameter is passed to the
+              `torch.nn.attention.sdpa_kernel <https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html>`_
+              context manager, which selects the backend used by
+              ``torch.nn.functional.scaled_dot_product_attention``.
+
+        Examples
+        --------
+        >>> from physicsnemo.models.diffusion.layers import Attention
+        >>> from torch.nn.attention import SDPBackend
+        >>> import torch
+        >>> model = Attention(out_channels=16, num_heads=2)
+        >>> x = torch.randn(1, 16, 8, 8)
+        >>> # Default: use torch.nn.functional.scaled_dot_product_attention
+        >>> # without a specific backend
+        >>> y0 = model(x)
+        >>> # Use the custom Python implementation of attention
+        >>> with Attention.SDPA_backend("python"):
+        ...     y1 = model(x)
+        >>> # Use a specific backend (C++ backend provided by PyTorch)
+        >>> with Attention.SDPA_backend(SDPBackend.MATH):
+        ...     y2 = model(x)
+        """
+        token = Attention._SDPA_BACKEND.set(backend)
+        try:
+            if backend is None or backend == "python":
+                yield
+            else:
+                with sdpa_kernel(backend) as resource:
+                    yield resource
+        finally:
+            Attention._SDPA_BACKEND.reset(token)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x1: torch.Tensor = self.qkv(self.norm(x))
 
-        # # NOTE: V1.0.1 implementation
-        # q, k, v = x1.reshape(
-        #     x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1
-        # ).unbind(2)
-        # w = AttentionOp.apply(q, k)
-        # attn = torch.einsum("nqk,nck->ncq", w, v)
-
-        q, k, v = (
-            (
-                x1.reshape(
-                    x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1
-                )
-            )
-            .permute(0, 1, 4, 3, 2)
-            .unbind(-2)
-        )
-        attn = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, scale=1 / math.sqrt(k.shape[-1])
-        )
-        attn = attn.transpose(-1, -2)
+        # Get the SDPA backend selected by the context manager (if any)
+        _sdpa_backend = self.__class__._get_sdpa_backend()
+
+        # Custom Python implementation (V1.0.1 implementation)
+        if _sdpa_backend == "python":
+            q, k, v = x1.reshape(
+                x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1
+            ).unbind(2)
+            w = AttentionOp.apply(q, k)
```
Review comment: I don't know if this matters in your numerical stability issues, but the PyTorch SDPA appears to use a different scale factor than AttentionOp does. Is the sequence length of q / k in these models always equal?
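One way to probe this question is to compare the two code paths directly with the context manager added in this PR. The following is a sketch only: it reuses the constructor arguments from the PR's docstring example, and assumes the layer runs end to end in float64.

```python
import torch
from physicsnemo.models.diffusion.layers import Attention

torch.manual_seed(0)
model = Attention(out_channels=16, num_heads=2).double()
x = torch.randn(1, 16, 8, 8, dtype=torch.float64)

# Default path: torch.nn.functional.scaled_dot_product_attention
y_sdpa = model(x)

# "python" path: AttentionOp plus the explicit einsum
with Attention.SDPA_backend("python"):
    y_python = model(x)

# If both paths apply the same softmax scaling, this difference should be
# small (limited by round-off and any internal dtype casts).
print((y_sdpa - y_python).abs().max())
```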
```diff
+            attn = torch.einsum("nqk,nck->ncq", w, v)
+        # Implementation based on torch.nn.functional.scaled_dot_product_attention
+        else:
+            q, k, v = (
+                (
+                    x1.reshape(
+                        x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1
+                    )
+                )
+                .permute(0, 1, 4, 3, 2)
+                .unbind(-2)
```
Review comment (on lines +794 to +801): It might be easier to maintain and follow the reshaping and permuting if we used einops here?
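For reference, a possible einops formulation, as a sketch only: the axis names and the assumed (num_heads, head_dim, qkv) channel layout are taken from the existing reshape, and the file already imports `rearrange`.

```python
import torch
from einops import rearrange

B, C, H, W, num_heads = 1, 16, 8, 8, 2
x1 = torch.randn(B, 3 * C, H, W)  # stands in for self.qkv(self.norm(x))

# Current formulation in the PR: reshape + permute + unbind
q0, k0, v0 = (
    x1.reshape(B, num_heads, C // num_heads, 3, -1)
    .permute(0, 1, 4, 3, 2)
    .unbind(-2)
)

# Possible einops equivalent of the same operation
q1, k1, v1 = rearrange(
    x1, "b (heads ch three) h w -> three b heads (h w) ch", heads=num_heads, three=3
).unbind(0)

# Each tensor has shape (B, num_heads, H * W, C // num_heads)
assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)
```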
```diff
+            )
+            attn = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, scale=1 / math.sqrt(k.shape[-1])
+            )
+            attn = attn.transpose(-1, -2)
 
         x: torch.Tensor = self.proj(attn.reshape(*x.shape)).add_(x)
         return x
```
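Because the backend selection lives in a ContextVar, it is scoped to the with block. A quick sanity check, as a sketch that relies on the private `_get_sdpa_backend` helper shown in the diff:

```python
from physicsnemo.models.diffusion.layers import Attention

# Outside any context manager, no backend override is active.
assert Attention._get_sdpa_backend() is None

with Attention.SDPA_backend("python"):
    # Inside the block, the custom Python implementation is selected.
    assert Attention._get_sdpa_backend() == "python"

# On exit, the ContextVar is reset, so later calls use the default again.
assert Attention._get_sdpa_backend() is None
```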
Review comment: A nit, but only because I've been working on docs so much lately: I think rst indentation expects 3, not 4, spaces in blocks like this.