
Forward-mode autodiff for fused kernel #224

@Birch-san

Description


back in #68, we tackled jvp for natten2dqk and natten2dav. for use-cases such as:

  • training consistency models or Nvidia GENIE heads
    • a diffusion model typically learns to output a value (e.g. velocity, v) from which we can compute the first derivative of the reverse diffusion ODE. but with forward-mode autodiff we can also compute the second derivative of that ODE, and moreover train the model to drive that second derivative towards 0, finding straight-trajectory solutions which can be Euler-integrated accurately in fewer steps (see the sketch after this list)
  • at inference-time, using a Taylor-method sampler to compute the first and second derivatives of the ODE, to obtain a more accurate trajectory which can turn corners better
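
to make the first bullet concrete, here's a minimal sketch of how forward-mode autodiff yields the second derivative of the ODE trajectory. it assumes pytorch's `torch.func` API and uses a toy stand-in for the velocity model; for a real diffusion transformer, every op inside the model (including attention) needs a jvp rule, which is the whole point of this issue:

```python
# minimal sketch, not from this issue: a toy velocity model stands in for a
# real diffusion transformer. `velocity`, shapes, and values are illustrative.
import torch
from torch.func import jvp

def velocity(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # stand-in for a model predicting the ODE velocity v = dx/dt
    return torch.sin(t) * x

x = torch.randn(2, 3, 8, 8)
t = torch.tensor(0.5)

# first derivative of the reverse-diffusion ODE: dx/dt = v(x, t)
v = velocity(x, t)

# second derivative via forward-mode autodiff, pushing the tangent (v, 1)
# through the model: d2x/dt2 = dv/dt = (dv/dx)·v + dv/dt
_, d2x_dt2 = jvp(velocity, (x, t), (v, torch.ones_like(t)))
```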

back then we just looked at the matmuls, for which the jvp derivation was understood. and since we were using unfused APIs, we were able to just use whatever pytorch provides for softmax.
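for reference, the unfused route just composes per-op jvp rules that pytorch already knows. a rough sketch (names and shapes illustrative, not NATTEN's API) of what forward-mode autodiff does through unfused attention:

```python
# jvp through unfused attention via pytorch's built-in rules:
#   dS = dQ Kᵀ + Q dKᵀ                 (product rule on the similarity matmul)
#   dP = P ⊙ (dS − rowsum(P ⊙ dS))     (softmax jvp, handled by pytorch)
#   dO = dP V + P dV                   (product rule on the aggregation matmul)
import torch
from torch.func import jvp

def attention(q, k, v):
    s = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    return torch.softmax(s, dim=-1) @ v

q, k, v = (torch.randn(1, 4, 16, 32) for _ in range(3))
dq, dk, dv = (torch.randn_like(q) for _ in range(3))
out, dout = jvp(attention, (q, k, v), (dq, dk, dv))
```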

after that, NATTEN introduced the fused kernel. no derivation was known (to us) for the jvp of an attention which employed an online softmax.

but then Cheng Lu joined OpenAI and, with Yang Song, wrote a new consistency training paper:
Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models

within which (Appendix F) a derivation for "jvp of flash attention" is shared:

[screenshots of the Appendix F derivation]

discussions raising visibility with the flash attention team are still early (Dao-AILab/flash-attention#1672) (I regret not mentioning it in October!). I think the flex attention team are also aware (https://discord.com/channels/729741769192767510/1079865324087803985/1296191340870238208), but I'm not sure whether @Chillee still works on pytorch.

consequently, the means to implement a memory-efficient jvp are known, but I believe there is no public implementation!
I say memory-efficient rather than flash or hardware-aware because I believe the significant part of the derivation is the online softmax, which is what enables fusion with the QK similarity computation.
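to illustrate, here's my own reading of the structure, written as plain pytorch rather than a kernel, and not the paper's exact formulation: alongside the usual online-softmax running max, denominator, and output accumulator, a single pass over key blocks can also accumulate the tangent terms, and the tangent output falls out at the end:

```python
# reference sketch only: single-head 2-D tensors, block loop simulating what a
# fused kernel would accumulate on-chip. not NATTEN or flash-attention code.
import torch
from torch.func import jvp

def attn(q, k, v):
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

def online_attn_jvp(q, k, v, dq, dk, dv, block=16):
    nq, d_v = q.shape[0], v.shape[-1]
    m = torch.full((nq, 1), float("-inf"))   # running row max
    l = torch.zeros(nq, 1)                   # running softmax denominator
    acc_o = torch.zeros(nq, d_v)             # unnormalised output
    acc_do = torch.zeros(nq, d_v)            # unnormalised tangent terms
    acc_mu = torch.zeros(nq, 1)              # unnormalised rowsum(P ⊙ dS)
    for j in range(0, k.shape[0], block):
        kj, vj = k[j:j + block], v[j:j + block]
        dkj, dvj = dk[j:j + block], dv[j:j + block]
        s = q @ kj.T
        ds = dq @ kj.T + q @ dkj.T           # jvp of the QK similarity
        m_new = torch.maximum(m, s.amax(-1, keepdim=True))
        scale = torch.exp(m - m_new)         # rescale earlier accumulators
        p = torch.exp(s - m_new)             # unnormalised probabilities
        l = l * scale + p.sum(-1, keepdim=True)
        acc_o = acc_o * scale + p @ vj
        acc_do = acc_do * scale + (p * ds) @ vj + p @ dvj
        acc_mu = acc_mu * scale + (p * ds).sum(-1, keepdim=True)
        m = m_new
    o = acc_o / l
    mu = acc_mu / l
    return o, acc_do / l - mu * o            # dO = (P⊙dS)V + P dV − μ⊙O

q, dq = torch.randn(8, 16), torch.randn(8, 16)
k, dk = torch.randn(64, 16), torch.randn(64, 16)
v, dv = torch.randn(64, 16), torch.randn(64, 16)
o_ref, do_ref = jvp(attn, (q, k, v), (dq, dk, dv))
o, do = online_attn_jvp(q, k, v, dq, dk, dv, block=16)
assert torch.allclose(o, o_ref, atol=1e-4)
assert torch.allclose(do, do_ref, atol=1e-4)
```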

there are more papers on few-step image generation coming out, relying on jvp of attention:
https://arxiv.org/abs/2505.13447

and practitioners are even implementing such papers at small scale (albeit substituting a finite-difference approximation for jvp):
https://x.com/LodestoneE621/status/1926520450068124037

so, there's an opportunity to be the first attention kernel to support memory-efficient jvp!
in fact, since any kernel size is supported, a jvp-capable NATTEN could also serve as a cheeky way to access jvp for global self-attention, when kernel size == canvas size. well, for people who train on squares (sufficient for research, though not for production models).

I'll also add that I'm raising this just for completeness and don't have any current plans involving this functionality. but yeah, the algorithm is known, so I figured I'd increase awareness.
