Benchmark results are provided in the Jupyter notebooks alongside the kernel implementations.
*(Figures: dropout benchmark and LayerNorm benchmark plots; see the notebooks for the rendered charts.)*
The key idea behind optimizing these kernels is to perform as many operations on-chip as possible, minimizing writes of intermediate results back to DRAM.
For random dropout, we use `tl.rand`, which uses the Philox RNG to generate a random number for each element on the fly, so no dropout mask ever needs to be read from DRAM.
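As a rough sketch of that pattern (the kernel name, block size, and scaling convention here are illustrative, not necessarily what the notebook uses):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _dropout_kernel(x_ptr, out_ptr, n_elements, p, seed, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Philox-based RNG: one uniform number per element, generated on-chip,
    # so no dropout mask is ever read from DRAM.
    random = tl.rand(seed, offsets)
    keep = random > p
    # Scale kept elements by 1/(1-p) so the expected value is unchanged.
    out = tl.where(keep, x / (1 - p), 0.0)
    tl.store(out_ptr + offsets, out, mask=mask)


def dropout(x: torch.Tensor, p: float = 0.5, seed: int = 42) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _dropout_kernel[grid](x, out, n_elements, p, seed, BLOCK_SIZE=1024)
    return out
```

Because the mask is recomputable from the seed alone, the backward pass can regenerate it instead of storing it.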
For layer normalization, the forward pass computes the mean and variance on-chip. The backward pass (see the derivation below) is split into two stages. In the first stage, one program per row computes DX, DW, and DB. Since DW and DB must be accumulated across the batch, each row's contribution is first folded into one of GROUP_SIZE_M partial sums under a spinlock; the partial buffers are small enough to stay resident in L1/L2 cache. The second stage reduces these partial sums into the final DW and DB. This approach is fast because the stage-2 reduction happens in registers (a warp reduce).

GROUP_SIZE_M is a key knob. If it is too small, many programs contend for the same lock; if it is too big, there are more partials, and thus more work for stage 2 and higher L2 pressure.
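The synchronization pattern, isolated from the rest of the backward kernel, looks roughly like the sketch below. All names (`_accumulate_partials`, `_reduce_partials`, `lock_ptr`, ...) are illustrative; the real stage-1 kernel computes the per-row DW/DB contributions in registers as part of the fused DX kernel rather than loading them from precomputed arrays, and the host is assumed to zero-initialize the partial buffers and locks.

```python
import triton
import triton.language as tl


@triton.jit
def _accumulate_partials(dw_row_ptr, db_row_ptr,          # per-row contributions, shape (M, N)
                         partial_dw_ptr, partial_db_ptr,  # partial sums, shape (GROUP_SIZE_M, N)
                         lock_ptr,                        # int32 locks, shape (GROUP_SIZE_M,)
                         N, GROUP_SIZE_M: tl.constexpr, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    group = row % GROUP_SIZE_M          # which partial buffer this row folds into
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    dw = tl.load(dw_row_ptr + row * N + cols, mask=mask, other=0.0)
    db = tl.load(db_row_ptr + row * N + cols, mask=mask, other=0.0)
    # Spin until we own this group's lock (0 = free, 1 = held).
    while tl.atomic_cas(lock_ptr + group, 0, 1) == 1:
        pass
    partial_dw = tl.load(partial_dw_ptr + group * N + cols, mask=mask, other=0.0)
    partial_db = tl.load(partial_db_ptr + group * N + cols, mask=mask, other=0.0)
    tl.store(partial_dw_ptr + group * N + cols, partial_dw + dw, mask=mask)
    tl.store(partial_db_ptr + group * N + cols, partial_db + db, mask=mask)
    # Release the lock.
    tl.atomic_xchg(lock_ptr + group, 0)


@triton.jit
def _reduce_partials(partial_dw_ptr, partial_db_ptr, dw_ptr, db_ptr,
                     N, GROUP_SIZE_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    cols = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = cols < N
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    # GROUP_SIZE_M is small, so this loop accumulates in registers.
    for g in range(GROUP_SIZE_M):
        dw += tl.load(partial_dw_ptr + g * N + cols, mask=mask, other=0.0)
        db += tl.load(partial_db_ptr + g * N + cols, mask=mask, other=0.0)
    tl.store(dw_ptr + cols, dw, mask=mask)
    tl.store(db_ptr + cols, db, mask=mask)
```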
LayerNorm is calculated per vector, broadcast over the batch. For each input vector $x \in \mathbb{R}^N$, it is

$$y = \frac{x - \mu}{\hat\sigma} \odot w + b$$

where $\mu = \frac{1}{N}\sum_{k} x_k$, $\sigma^2 = \frac{1}{N}\sum_{k}(x_k - \mu)^2$, and $\hat\sigma = \sqrt{\sigma^2 + \epsilon}$.

For back prop, recall the chain rule

$$\frac{\partial L}{\partial x_i} = \sum_{j} \frac{\partial L}{\partial y_j}\,\frac{\partial y_j}{\partial x_i}$$

with

$$y_j = w_j\,\hat x_j + b_j, \qquad \hat x_j = \frac{x_j - \mu}{\hat\sigma}$$

and

$$\frac{\partial \mu}{\partial x_i} = \frac{1}{N}, \qquad \frac{\partial \sigma^2}{\partial x_i} = \frac{2}{N}\,(x_i - \mu).$$

Therefore, for each input vector,

$$\frac{\partial L}{\partial x_i} = \frac{1}{\hat\sigma}\left( w_i\,\frac{\partial L}{\partial y_i} - \frac{1}{N}\sum_{j} w_j\,\frac{\partial L}{\partial y_j} - \frac{\hat x_i}{N}\sum_{j} w_j\,\frac{\partial L}{\partial y_j}\,\hat x_j \right).$$

Notice there are values that can be shared across all $i$. If we let

$$c_1 = \frac{1}{N}\sum_{j} w_j\,\frac{\partial L}{\partial y_j}, \qquad c_2 = \frac{1}{N}\sum_{j} w_j\,\frac{\partial L}{\partial y_j}\,\hat x_j,$$

then

$$\frac{\partial L}{\partial x_i} = \frac{1}{\hat\sigma}\left( w_i\,\frac{\partial L}{\partial y_i} - c_1 - \hat x_i\,c_2 \right),$$

where $\hat x_i$ is the normalized input already computed in the forward pass, so each row needs only two extra reductions.

Gradients for learnable weights and biases are easy to calculate:

$$\frac{\partial L}{\partial w_j} = \sum_{\text{rows}} \frac{\partial L}{\partial y_j}\,\hat x_j, \qquad \frac{\partial L}{\partial b_j} = \sum_{\text{rows}} \frac{\partial L}{\partial y_j}.$$
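As a sanity check on these closed-form gradients, a small PyTorch snippet (shapes and seed arbitrary) can compare them against autograd:

```python
import torch

torch.manual_seed(0)
M, N, eps = 4, 8, 1e-5
x = torch.randn(M, N, dtype=torch.float64, requires_grad=True)
w = torch.randn(N, dtype=torch.float64, requires_grad=True)
b = torch.randn(N, dtype=torch.float64, requires_grad=True)
dy = torch.randn(M, N, dtype=torch.float64)

# Forward pass, then let autograd produce reference gradients.
mu = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, unbiased=False, keepdim=True)
sigma_hat = torch.sqrt(var + eps)
x_hat = (x - mu) / sigma_hat
y = x_hat * w + b
y.backward(dy)

# Closed-form expressions from the derivation above.
with torch.no_grad():
    wdy = w * dy
    c1 = wdy.mean(dim=1, keepdim=True)
    c2 = (wdy * x_hat).mean(dim=1, keepdim=True)
    dx = (wdy - c1 - x_hat * c2) / sigma_hat
    dw = (dy * x_hat).sum(dim=0)
    db = dy.sum(dim=0)

assert torch.allclose(dx, x.grad)
assert torch.allclose(dw, w.grad)
assert torch.allclose(db, b.grad)
```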