9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -98,10 +98,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`Module.from_checkpoint` that now exposes a `strict` parameter to raise error
on missing/unexpected keys, similar to that used in
`torch.nn.Module.load_state_dict`.

### Deprecated

### Removed
- Diffusion models: adds a context manager `Attention.SDPA_backend` to control
  the backend used for the attention mechanism defined in
  `physicsnemo.models.diffusion.layers.Attention`. Allows selecting between
  attention computation based on `torch.nn.functional.scaled_dot_product_attention`
  and a custom Python implementation.

### Fixed

1 change: 1 addition & 0 deletions physicsnemo/models/diffusion/__init__.py
@@ -24,6 +24,7 @@
Linear,
PositionalEmbedding,
UNetBlock,
Attention,
)
from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd
from .dhariwal_unet import DhariwalUNet
114 changes: 95 additions & 19 deletions physicsnemo/models/diffusion/layers.py
@@ -22,12 +22,14 @@
import contextlib
import importlib
import math
from typing import Any, Dict, List, Set
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Literal, Set

import numpy as np
import nvtx
import torch
from einops import rearrange
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh

from physicsnemo.models.diffusion import weight_init
@@ -656,6 +658,14 @@ class Attention(torch.nn.Module):
-------
torch.Tensor
Output tensor of the same shape as input: :math:`(B, C, H, W)`.

    .. important::

        By default, this implementation uses scaled dot product attention (SDPA) from
        `torch.nn.functional.scaled_dot_product_attention
        <https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_.
        This operator is optimized for performance, but it is still in beta and its
        results might be affected by numerical errors. To use a pure Python
        implementation, set the SDPA backend to ``"python"`` using the
        :meth:`SDPA_backend` context manager.
Comment on lines +662 to +668
Collaborator

A nit, but only because I've been working on docs so much lately: I think rst indentation expects 3, not 4, spaces in blocks like this.

"""

def __init__(
@@ -682,6 +692,7 @@ def __init__(
f"`out_channels` must be divisible by `num_heads`, but got {out_channels} and {num_heads}"
)
self.num_heads = num_heads
self._sdpa_backend = None
self.norm = get_group_norm(
num_channels=out_channels,
eps=eps,
@@ -705,29 +716,94 @@ def __init__(
**init_zero,
)

    # Class variable for selecting the SDPA backend via the context manager
_SDPA_BACKEND: ContextVar[Any] = ContextVar("sdpa_backend", default=None)

@classmethod
def _get_sdpa_backend(
cls,
) -> None | SDPBackend | List[SDPBackend] | Literal["python"]:
return cls._SDPA_BACKEND.get()

@staticmethod
@contextlib.contextmanager
def SDPA_backend(
backend: None | SDPBackend | List[SDPBackend] | Literal["python"],
) -> Generator[None, None, None]:
"""
Context manager to select the SDPA backend.

Parameters
----------
backend : None | SDPBackend | List[SDPBackend] | Literal["python"]
- If ``None``, the default implementation based on
``torch.nn.functional.scaled_dot_product_attention`` is used without
any specific backend.
- If ``"python"``, a custom python implementation of attention
based on :class:`~physicsnemo.models.diffusion.layers.AttentionOp`
is used. This backend is less performant but less sensitive to
numerical errors.
            - In all other cases, the ``backend`` parameter is passed directly to the
              `torch.nn.attention.sdpa_kernel
              <https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html>`_
              context manager, which selects the backend used by
              ``torch.nn.functional.scaled_dot_product_attention``.

Examples
--------
>>> from physicsnemo.models.diffusion.layers import Attention
>>> from torch.nn.attention import SDPBackend
>>> import torch
>>> model = Attention(out_channels=16, num_heads=2)
>>> x = torch.randn(1, 16, 8, 8)
>>> # Default: use torch.nn.functional.scaled_dot_product_attention
>>> # without specific backend
>>> y0 = model(x)
>>> # Use custom python implementation of attention
>>> with Attention.SDPA_backend("python"):
... y1 = model(x)
>>> # Use specific backend (C++ backend provided in pytorch)
>>> with Attention.SDPA_backend(SDPBackend.MATH):
... y2 = model(x)
"""
token = Attention._SDPA_BACKEND.set(backend)
try:
if backend is None or backend == "python":
yield
else:
with sdpa_kernel(backend) as resource:
yield resource
finally:
Attention._SDPA_BACKEND.reset(token)
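
For readers unfamiliar with the pattern, here is a standalone sketch (not part of the diff, names are illustrative only) of the `ContextVar` set/reset idiom used above: `reset(token)` restores whatever value was active before entering the context, which is what makes nested `with` blocks unwind correctly.

import contextlib
from contextvars import ContextVar
from typing import Optional

# Hypothetical backend selector, mirroring Attention._SDPA_BACKEND
_BACKEND: ContextVar[Optional[str]] = ContextVar("backend", default=None)

@contextlib.contextmanager
def use_backend(name: Optional[str]):
    # set() returns a token remembering the previous value
    token = _BACKEND.set(name)
    try:
        yield
    finally:
        # reset() restores the previous value, not the default,
        # so nested contexts compose correctly
        _BACKEND.reset(token)

with use_backend("a"):
    with use_backend("b"):
        assert _BACKEND.get() == "b"
    assert _BACKEND.get() == "a"  # outer value restored
assert _BACKEND.get() is None     # back to the default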

def forward(self, x: torch.Tensor) -> torch.Tensor:
x1: torch.Tensor = self.qkv(self.norm(x))

        # Get the SDPA backend selected via the SDPA_backend context manager
        _sdpa_backend = self.__class__._get_sdpa_backend()

        # Custom python implementation (V1.0.1 implementation)
        if _sdpa_backend == "python":
            q, k, v = x1.reshape(
                x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1
            ).unbind(2)
            w = AttentionOp.apply(q, k)
Collaborator

I don't know if this matters in your numerical stability issues. But the pytorch sdpa appears to use a scale factor:

scale_factor = 1 / math.sqrt(query.size(1)) if scale is None else scale

On the other hand, AttentionOp uses:

(k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32)

Is the sequence length of q / k in these models always equal?
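
Not part of the diff, but to make the comparison concrete: assuming `k.shape[1]` inside `AttentionOp` is the per-head channel dimension (the same head dim the SDPA path scales by via `scale=1 / math.sqrt(k.shape[-1])`), the two backends should agree up to floating point error. A quick check along those lines:

import torch
from physicsnemo.models.diffusion.layers import Attention

torch.manual_seed(0)
model = Attention(out_channels=16, num_heads=2)
x = torch.randn(2, 16, 8, 8)

with torch.no_grad():
    y_sdpa = model(x)  # default path: torch SDPA, scale = 1/sqrt(head_dim)
    with Attention.SDPA_backend("python"):
        y_python = model(x)  # AttentionOp path: k scaled by 1/sqrt(k.shape[1])

# Both scales reduce to 1/sqrt(out_channels // num_heads), so the outputs
# should only differ by numerical precision.
print(torch.allclose(y_sdpa, y_python, atol=1e-5))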

attn = torch.einsum("nqk,nck->ncq", w, v)
# Implementation based on torch.nn.functional.scaled_dot_product_attention
else:
q, k, v = (
(
x1.reshape(
x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1
)
)
.permute(0, 1, 4, 3, 2)
.unbind(-2)
            )

Comment on lines +794 to +801
Collaborator

It might be easier to maintain and follow the reshaping and permuting if we used einops here?

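To make the einops suggestion concrete, a possible equivalent of the reshape/permute/unbind chain above (a sketch only, not part of the PR; it assumes `x1` has shape `(B, 3*C, H, W)` with channels laid out as `(head, channel, qkv)`, as in the reshape above):

# rearrange is already imported at the top of layers.py
q, k, v = rearrange(
    x1,
    "b (head c three) h w -> three b head (h w) c",
    head=self.num_heads,
    three=3,
).unbind(0)
# q, k, v: (B, num_heads, H*W, C // num_heads), matching the permute-based version
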
attn = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=1 / math.sqrt(k.shape[-1])
)
attn = attn.transpose(-1, -2)

x: torch.Tensor = self.proj(attn.reshape(*x.shape)).add_(x)
return x