back in #68, we tackled jvp for natten2dqk and natten2dav, for use-cases such as:
- training consistency models or Nvidia GENIE heads
- a diffusion model typically learns to output a value (e.g. velocity, v) from which we can compute the first derivative of the reverse diffusion ODE. with forward-mode autodiff we can also compute the second derivative of that ODE, and moreover train the model to drive that second derivative towards zero, finding straight-trajectory solutions which can be Euler-integrated accurately in fewer steps (see the sketch after this list)
- at inference-time, using a Taylor-method sampler to compute the first and second derivatives of the ODE, to obtain a more accurate trajectory which can turn corners better
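to make the training use-case concrete, here's a minimal sketch of getting the second derivative of the ODE via `torch.func.jvp`. the `velocity_model(x, t)` callable is hypothetical (not any NATTEN API); the point is just that forward-mode autodiff has to flow through the model, and hence through its attention layers.

```python
import torch
from torch.func import jvp

def ode_derivatives(velocity_model, x, t):
    """First and second time-derivatives of the ODE dx/dt = v(x, t).

    d2x/dt2 = (dv/dx) v + dv/dt, which is exactly a jvp of v at (x, t)
    with tangents (v, 1): forward-mode autodiff, no Jacobian materialised.
    `velocity_model` is a hypothetical stand-in for the diffusion model.
    """
    v = velocity_model(x, t)                                      # dx/dt
    _, d2x_dt2 = jvp(velocity_model, (x, t), (v, torch.ones_like(t)))
    return v, d2x_dt2
```

a training objective can then penalise `d2x_dt2` (or feed it into the consistency objective), which is where jvp support inside the attention kernel becomes necessary.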
back then we just looked at the matmuls, for which the jvp derivation was understood. and since we were using unfused APIs, we were able to just use whatever pytorch provides for softmax.
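for illustration, the unfused situation looked roughly like this, with plain global attention standing in for the natten2dqk → softmax → natten2dav pipeline (this is a sketch, not the NATTEN API): each op already has a forward-mode rule, so `torch.func.jvp` composes through the whole thing.

```python
import torch
from torch.func import jvp

def unfused_attention(q, k, v):
    # stand-in for the unfused pipeline: QK similarity, softmax, then AV.
    # each step is an op pytorch already knows how to differentiate forward-mode.
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v

q, k, v = (torch.randn(1, 8, 64, 32) for _ in range(3))
tq, tk, tv = (torch.randn_like(x) for x in (q, k, v))

# forward-mode autodiff through the composed ops; no bespoke derivation needed.
out, out_tangent = jvp(unfused_attention, (q, k, v), (tq, tk, tv))
```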
after that, NATTEN introduced the fused kernel. no derivation was known (to us) for the jvp of an attention which employed an online softmax.
but then Cheng Lu joined OpenAI and, with Yang Song, wrote a new consistency training paper:
Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models
within which (Appendix F) a derivation for "jvp of flash attention" is shared:
discussions are still early; I'm raising visibility with the flash attention team (Dao-AILab/flash-attention#1672) (I regret not mentioning it back in October!). I think the flex attention team is also aware (https://discord.com/channels/729741769192767510/1079865324087803985/1296191340870238208), but I'm not sure whether @Chillee still works on pytorch.
consequently, the means to implement a memory-efficient jvp is known, but I believe there is no public implementation!
I say memory-efficient instead of flash or hardware-aware, because I believe the significant part of the derivation is online softmax, which enables fusion with QK similarity.
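to spell out what such a kernel has to compute, here is the jvp of plain softmax attention written out unfused (my own reference sketch, not the paper's algorithm verbatim). the row-wise term `mu` is the piece that has to be carried through the online-softmax rewrite:

```python
import torch

def attention_jvp_reference(q, k, v, tq, tk, tv):
    # primal: S = QK^T * scale, P = softmax(S), O = PV
    scale = q.shape[-1] ** -0.5
    s = q @ k.transpose(-2, -1) * scale
    ts = (tq @ k.transpose(-2, -1) + q @ tk.transpose(-2, -1)) * scale  # dS
    p = torch.softmax(s, dim=-1)
    mu = (p * ts).sum(dim=-1, keepdim=True)  # row-wise <P, dS>: couples the score tangent to the primal weights
    o = p @ v
    # dO = (P * dS) V  -  mu * O  +  P dV
    to = (p * ts) @ v - mu * o + p @ tv
    return o, to
```

this matches `torch.func.jvp` on the unfused form; the open problem is producing the same quantities tile by tile while the softmax stays online.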
there are more papers on few-step image generation coming out, relying on jvp of attention:
https://arxiv.org/abs/2505.13447
and practitioners are even implementing such papers at small scale, albeit substituting jvp with a finite-difference approximation (sketched below the link):
https://x.com/LodestoneE621/status/1926520450068124037
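for reference, the finite-difference substitute is roughly this (a sketch, not that practitioner's exact code):

```python
import torch

def finite_difference_jvp(f, primals, tangents, eps=1e-3):
    # central-difference approximation to the jvp J_f(x) @ u.
    # cheap to write, but noisy in low precision, hence the appeal of a true
    # forward-mode attention kernel.
    plus = f(*(p + eps * u for p, u in zip(primals, tangents)))
    minus = f(*(p - eps * u for p, u in zip(primals, tangents)))
    return (plus - minus) / (2 * eps)
```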
so, there's an opportunity for NATTEN to be the first attention kernel to support memory-efficient jvp!
in fact, since any-sized kernels are supported, a jvp NATTEN could also be used as a cheeky way to access jvp for global self-attention when kernel size == canvas size. well, for people who train on squares (sufficient for research, though not for production models).
I'll also add that I'm raising this for completeness and don't have any current plans involving this functionality. but yeah, the algorithm is known, so I figured I'd increase awareness.