Added backend context manager to select SDPA implementation #1061
base: main
Conversation
Signed-off-by: Charlelie Laurent <[email protected]>
/blossom-ci
I don't want to hold up your review @CharlelieLrt. But a question - PyTorch already supports a context manager to change the behavior of attention layers. In particular, it offers an implementation that claims to provide the "classic" style attention, but as a single C++ op.
torch.backends.cuda.enable_math_sdp()
Have you tried the sdpa with that backend to see if you get better numerical agreement?
Additionally, regardless of the results, could you add some details on the level of instabilities you're seeing? I think it will help us in the future to have it saved somewhere saying "these layers diverged by X amount on CPU vs GPU0 vs GPU1", etc.
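For reference, a minimal sketch of pinning SDPA to that math backend (this assumes a recent PyTorch where torch.nn.attention.sdpa_kernel is available; older releases expose torch.backends.cuda.sdp_kernel / enable_math_sdp instead):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy tensors in the (batch, heads, seq, head_dim) layout SDPA expects.
q, k, v = (torch.randn(2, 4, 128, 64, device="cuda") for _ in range(3))

# Default dispatch: may pick the flash or memory-efficient kernels on GPU.
out_default = F.scaled_dot_product_attention(q, k, v)

# Restrict SDPA to the C++ "math" reference implementation inside this block.
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)

# Rough indication of how far the two backends diverge on this input.
print((out_default - out_math).abs().max())
```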
```rst
.. important::
    This implementation uses by default the implementation of scaled dot product attention (SDPA) from
    `torch.nn.functional.scaled_dot_product_attention`<https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>_.
    This operator is optimized for performance, but is still in beta and its results might be affected by numerical errors.
    To use a pure python implementation, set the SDPA backend to "python" using the
    :meth:`SDPA_backend` context manager.
```
A nit, but only because I've been working on docs so much lately: I think rst indentation expects 3, not 4, spaces in blocks like this.
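For example, one possible three-space version of the directive (the hyperlink target is also rewrapped here into the usual `text <url>`_ form):

```rst
.. important::

   This implementation uses by default the implementation of scaled dot product attention (SDPA) from
   `torch.nn.functional.scaled_dot_product_attention <https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_.
   This operator is optimized for performance, but is still in beta and its results might be affected by numerical errors.
   To use a pure python implementation, set the SDPA backend to "python" using the
   :meth:`SDPA_backend` context manager.
```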
```python
q, k, v = (
    (
        x1.reshape(
            x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1
        )
    )
    .permute(0, 1, 4, 3, 2)
    .unbind(-2)
)
```
It might be easier to maintain and follow the reshaping and permuting if we used einops here?
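Something along these lines, assuming x1 holds the fused qkv projection with its channel axis laid out as (heads, head_dim, 3) and an already-flattened spatial axis, as the reshape above implies:

```python
from einops import rearrange

# x1: (batch, num_heads * head_dim * 3, seq) -- layout inferred from the reshape above.
q, k, v = rearrange(
    x1, "b (h d three) n -> three b h n d", h=self.num_heads, three=3
).unbind(0)
# Each of q, k, v: (batch, num_heads, seq, head_dim),
# matching the permute(0, 1, 4, 3, 2).unbind(-2) result.
```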
```python
q, k, v = x1.reshape(
    x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1
).unbind(2)
w = AttentionOp.apply(q, k)
```
I don't know if this matters in your numerical stability issues. But the pytorch sdpa appears to use a scale factor:
scale_factor = 1 / math.sqrt(query.size(1)) if scale is None else scale
On the other hand, AttentionOp uses:
(k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32)
Is the sequence length of q / k in these models always equal?
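For what it's worth, a small sketch of the two scalings side by side, assuming the (batch * heads, head_dim, seq) layout implied by the reshape above (so that k.shape[1] would be the per-head channel dimension):

```python
import math
import torch

# Shapes assumed from the snippet above: (batch * num_heads, head_dim, seq).
bh, head_dim, seq = 8, 64, 128
q = torch.randn(bh, head_dim, seq)
k = torch.randn(bh, head_dim, seq)

# AttentionOp-style scaling: divide k by sqrt(k.shape[1]) == sqrt(head_dim).
w_custom = torch.einsum("ncq,nck->nqk", q, k / math.sqrt(k.shape[1]))

# SDPA documents its default scale as 1 / sqrt(E), with E the last (embedding)
# dimension of query in its (..., seq, E) layout -- also 1 / sqrt(head_dim) here.
w_sdpa_style = torch.einsum("ncq,nck->nqk", q, k) / math.sqrt(head_dim)

# Up to floating-point rounding, the two scalings agree under these assumed shapes.
print(torch.allclose(w_custom, w_sdpa_style))
```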
PhysicsNeMo Pull Request
Description
Context
PR #954 changed the computation of the Attention forward defined in physicsnemo/models/diffusion/layers.py so that it is now based on torch.nn.functional.scaled_dot_product_attention (instead of the former custom python implementation). This PyTorch API offers improved performance, but it is still in beta and is known to be hardware dependent and sensitive to numerical errors.
Changes
The present PR re-introduces the former attention computation, based on a custom python implementation, as an option. In comparison to torch.nn.functional.scaled_dot_product_attention, this implementation offers worse performance, but better numerical stability and less sensitivity to numerical errors. The default forward pass of the Attention class remains based on torch.nn.functional.scaled_dot_product_attention, but the custom python implementation can now be selected with a context manager, as sketched below:
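A sketch of the intended usage (the classmethod-style access to SDPA_backend and the exact import path are assumptions based on the docstring above; constructor arguments are omitted because they are not shown in this excerpt):

```python
import torch
from physicsnemo.models.diffusion.layers import Attention  # module path per this PR

attn = Attention(...)   # construct with the model's usual arguments (not shown here)
x = torch.randn(...)    # an input of whatever shape the module expects

# Default path: torch.nn.functional.scaled_dot_product_attention.
y_sdpa = attn(x)

# Temporarily switch to the custom pure-python implementation.
with Attention.SDPA_backend("python"):
    y_python = attn(x)
```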
Checklist
Dependencies