
Conversation

@CharlelieLrt (Collaborator)

PhysicsNeMo Pull Request

Description

Context

PR #954 changed the computation of the Attention forward defined in physicsnemo/models/diffusion/layers.py so that it is now based on torch.nn.functional.scaled_dot_product_attention (instead of the former custom Python implementation). This PyTorch API offers improved performance, but it is still in beta and is known to be hardware-dependent and sensitive to numerical errors.

Changes

The present PR re-introduces the former attention computation, based on a custom Python implementation, as an option. Compared to torch.nn.functional.scaled_dot_product_attention, this implementation offers worse performance but better numerical stability (i.e., lower sensitivity to numerical errors). The default forward pass of the Attention class remains based on torch.nn.functional.scaled_dot_product_attention, but the custom Python implementation can now be selected with a context manager:

from physicsnemo.models.diffusion.layers import Attention
from torch.nn.attention import SDPBackend
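
# `model` below stands for any module that uses this Attention layer internally,
# and `x` for a compatible input tensor (both assumed from the surrounding context).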

# Default: use torch.nn.functional.scaled_dot_product_attention
# without a specific backend
y = model(x)

# Use the custom Python implementation of attention
with Attention.SDPA_backend("python"):
    y = model(x)

# Use a specific PyTorch backend
# for torch.nn.functional.scaled_dot_product_attention
with Attention.SDPA_backend(SDPBackend.FLASH_ATTENTION):
    y = model(x)

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

@CharlelieLrt (Collaborator, Author)

/blossom-ci

@coreyjadams (Collaborator) left a comment:

I don't want to hold up your review @CharlelieLrt. But a question: PyTorch already supports a context manager to change the behavior of attention layers. In particular, it offers a C++ implementation that claims to provide the "classic"-style computation, but as a single C++ op.

torch.backends.cuda.enable_math_sdp()

Have you tried SDPA with that backend to see if you get better numerical agreement?
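
For concreteness, a minimal sketch of what forcing that backend could look like (a sketch only, assuming a PyTorch version that provides torch.nn.attention.sdpa_kernel; the shapes are illustrative):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative (batch, heads, seq, head_dim) tensors
q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))

# Route scaled_dot_product_attention through the C++ "math" reference backend
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)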

Additionally, regardless of the results, could you add some details on the level of instabilities you're seeing? I think it will help us in the future to have it saved somewhere saying "these layers diverged by X amount on CPU vs GPU0 vs GPU1", etc.
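
As a rough sketch of the kind of measurement that could be recorded (hypothetical shapes and dtypes, with a float64 manual attention used as the reference):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 1024, 64) for _ in range(3))

# Reference: plain softmax attention computed in float64
ref = ((q.double() @ k.double().transpose(-2, -1)) / 64**0.5).softmax(-1) @ v.double()

# Candidate: fused SDPA in float32 (repeat per device / dtype / backend of interest)
out = F.scaled_dot_product_attention(q, k, v)

print("max abs divergence:", (out.double() - ref).abs().max().item())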

Comment on lines +662 to +668
.. important::
    This implementation uses by default the implementation of scaled dot product attention (SDPA) from
    `torch.nn.functional.scaled_dot_product_attention <https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_.
    This operator is optimized for performance, but is still in beta and its results might be affected by numerical errors.
    To use a pure python implementation, set the SDPA backend to "python" using the
    :meth:`SDPA_backend` context manager.

Collaborator:

A nit, but only because I've been working on docs so much lately: I think RST indentation expects 3, not 4, spaces in blocks like this.

Comment on lines +794 to +801
q, k, v = (
    (
        x1.reshape(
            x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1
        )
    )
    .permute(0, 1, 4, 3, 2)
    .unbind(-2)

Collaborator:

It might be easier to maintain and follow the reshaping and permuting if we used einops here?
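
For illustration, a possible einops version of the quoted lines (a sketch only, not from the PR; it assumes x1 is laid out as (batch, heads * head_dim * 3, tokens) exactly as the reshape above implies, and einops is not currently a dependency):

from einops import rearrange

# Split the channel dim into (head, head_dim, qkv), move the qkv axis first, then unbind it
q, k, v = rearrange(
    x1, "b (h d three) n -> three b h n d", h=self.num_heads, three=3
).unbind(0)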

q, k, v = x1.reshape(
    x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1
).unbind(2)
w = AttentionOp.apply(q, k)

Collaborator:

I don't know if this matters in your numerical stability issues. But the PyTorch SDPA appears to use a scale factor:

scale_factor = 1 / math.sqrt(query.size(1)) if scale is None else scale

On the other hand, AttentionOp uses:

(k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32)

Is the sequence length of q / k in these models always equal?
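
For what it's worth, a small sanity check of how the two defaults line up (purely illustrative shapes; this assumes the documented default scale of 1 / sqrt(query.size(-1)) and reads the old reshape as producing (B*H, C//H, N) tensors):

import math
import torch

B, H, C, N = 2, 4, 256, 1024                 # batch, heads, channels, tokens (hypothetical)

# Old layout: q, k are (B*H, C//H, N), so k.shape[1] is the per-head channel dim
k_old = torch.randn(B * H, C // H, N)
scale_old = 1 / math.sqrt(k_old.shape[1])

# New layout: q is (B, H, N, C//H), so SDPA's default scale uses the last dim
q_new = torch.randn(B, H, N, C // H)
scale_new = 1 / math.sqrt(q_new.size(-1))

assert math.isclose(scale_old, scale_new)    # both reduce to 1 / sqrt(head_dim)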
