I am developing intuition for the heuristic used to compute the number of warps.
The `_layer_norm_fwd_fused` kernel uses `BLOCK_SIZE` as the stride of its accumulation for-loops over an input vector. But the `forward` method of the `LayerNorm` class sets `BLOCK_SIZE` to the next power of two of the input vector dimensionality, so each loop runs for exactly one iteration and there is no need for accumulation.
There is also no accumulation in the `_layer_norm_bwd_dx_fused` kernel, which takes the `BLOCK_SIZE` value from the context saved in the forward pass; this suggests the for-loops are not there to prompt the compiler. Because `BLOCK_SIZE_N` in the `_layer_norm_bwd_dx_fused` kernel cannot be less than the input vector dimensionality, the use of the for-loops in the `_layer_norm_fwd_fused` kernel can lead to potential bugs.
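To make the single-iteration claim concrete, here is a minimal pure-Python sketch of the loop structure (not the Triton kernel itself); `next_power_of_2` mirrors what `triton.next_power_of_2` computes, and the loop bounds mirror `for off in range(0, N, BLOCK_SIZE)`:

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n.
    return 1 << (n - 1).bit_length()

def fwd_loop_iterations(N: int) -> int:
    # Mirrors the accumulation loop bounds in _layer_norm_fwd_fused
    # when BLOCK_SIZE is set by LayerNorm.forward.
    BLOCK_SIZE = next_power_of_2(N)
    return len(range(0, N, BLOCK_SIZE))

# BLOCK_SIZE >= N always holds, so the loop body runs exactly once.
for N in (1, 255, 256, 257, 8192):
    assert fwd_loop_iterations(N) == 1
```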
An instance of the `_layer_norm_{fwd_fused, bwd_dx_fused}` kernel processes one entire input vector and uses a number of warps chosen by the heuristic. Up to 8 warps are used, resulting in up to 256 threads. To get a number of warps other than 8, `BLOCK_SIZE` is divided by 256. This suggests that (i) 256 contiguous 16-bit elements are accessed by the 32 threads of a warp, and (ii) each thread accesses 8 such elements in a single 128-bit vectorized load/store transaction.
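For reference, the heuristic as I read it can be sketched in a few lines of plain Python (a reconstruction of the tutorial's launch logic, not a verbatim copy):

```python
def num_warps_heuristic(BLOCK_SIZE: int) -> int:
    # One warp (32 threads) per 256 elements, clamped to the range [1, 8].
    return min(max(BLOCK_SIZE // 256, 1), 8)

assert num_warps_heuristic(128) == 1    # below 256 elements: still 1 warp
assert num_warps_heuristic(1024) == 4   # 1024 / 256 = 4 warps = 128 threads
assert num_warps_heuristic(8192) == 8   # clamped at 8 warps = 256 threads
```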
In the `_layer_norm_bwd_dwdb` kernel, the accesses to partial weight and bias gradients are in contiguous row segments of 128 16-bit elements. One half of a warp would access a contiguous row segment with one 128-bit vectorized transaction per thread, and the other half of the warp would similarly access another, non-adjacent contiguous row segment. Note that the default number of warps appears to be used here, in contrast to the `_layer_norm_{fwd_fused, bwd_dx_fused}` kernels.
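The half-warp claim follows from simple byte arithmetic (assuming 16-bit data and 128-bit per-thread vectorized accesses, as above):

```python
ELEM_BITS = 16        # fp16/bf16 element width (assumed)
SEGMENT_ELEMS = 128   # contiguous row segment in _layer_norm_bwd_dwdb
VEC_BITS = 128        # one vectorized load/store per thread

elems_per_thread = VEC_BITS // ELEM_BITS            # 8 elements per thread
threads_per_segment = SEGMENT_ELEMS // elems_per_thread  # 16 threads

# 16 threads = half of a 32-thread warp covers one 128-element segment,
# leaving the other half-warp for a second, non-adjacent segment.
assert elems_per_thread == 8
assert threads_per_segment == 16
```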
Based on this analysis, the heuristic for computing the number of warps uses the following criteria:

- one 128-bit vectorized load/store transaction per thread per data block,
- threads in a warp access one contiguous segment, or two non-adjacent contiguous segments, and
- a thread block preferably has up to 256 threads; 256 threads may be suitable for high occupancy across NVIDIA architectures.
I also assume that the `.sum` and `+=` accumulations are automatically optimized by the compiler into parallel tree reductions with O(log N) step complexity.
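The shape of such a reduction can be pictured with a small pure-Python sketch (an illustration of the O(log N) step count, not Triton's actual code generation):

```python
import math

def tree_sum(xs):
    # Each step halves the number of partial sums, so the number of
    # steps is ceil(log2(N)) even though total work stays O(N).
    vals = list(xs)
    steps = 0
    while len(vals) > 1:
        if len(vals) % 2:          # pad odd-length rounds with the identity
            vals.append(0)
        vals = [vals[i] + vals[i + 1] for i in range(0, len(vals), 2)]
        steps += 1
    return vals[0], steps

total, steps = tree_sum(range(1, 9))   # 1 + 2 + ... + 8
assert total == 36
assert steps == math.ceil(math.log2(8))  # 3 steps for 8 elements
```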
The purpose of the for-loops in the `_layer_norm_fwd_fused` kernel remains unclear. Any comments regarding the heuristic are also appreciated. Thank you.