1616#include < fmha/traits.h>
1717#include < fmha/utils.h>
1818
19+ #include " fmha/hopper/arrive_wait.h"
20+
1921namespace fmha {
2022namespace ws {
2123
@@ -71,6 +73,9 @@ struct Softmax_base {
7173 // Whether we need to check if local_max could be -inf or not.
7274 enum { CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK };
7375
76+ // There are 2 warpgroups so 0x3 and 0x4 are used
77+ enum { SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID };
78+
7479 // Ctor.
7580 template <typename Params>
7681 inline __device__ Softmax_base (Params params, int tidx)
@@ -80,7 +85,12 @@ struct Softmax_base {
8085 sliding_window_size_(params.sliding_window_size),
8186 log2_chunked_attention_size_(params.log2_chunked_attention_size),
8287 packed_mask_ptr_{reinterpret_cast <uint32_t *>(params.packed_mask_ptr )},
83- params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes } {
88+ params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes },
89+ #ifdef SKIP_SOFTMAX_STAT
90+ total_blocks (0 ),
91+ skipped_blocks(0 ),
92+ #endif
93+ skip_softmax_threshold (0 ) {
8494 int warp = tidx / 32 ;
8595 int lane = tidx % 32 ;
8696 // The corresponding row/col for each thread after MMA.
@@ -253,25 +263,67 @@ struct Softmax_base {
253263 }
254264
255265 // Calculate max/sum, and update flash-attention scales.
266+ // Returns false if skipped due to skip-softmax attention feature.
256267 template <bool IS_FIRST_COL>
257- inline __device__ void compute_and_update_scale (float (&global_max)[Mma_tile_p::CORES_M],
258- float (&global_sum)[Mma_tile_p::CORES_M]) {
268+ inline __device__ bool compute_and_update_scale (float (&global_max)[Mma_tile_p::CORES_M],
269+ float (&global_sum)[Mma_tile_p::CORES_M],
270+ uint32_t* skip_softmax_vote) {
259271 float const scale = reinterpret_cast <float const &>(scale_bmm1_);
260272
273+ // whether this warpgroup skips the softmax
274+ constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
275+ bool skip = may_skip;
276+
261277// Row-wise max of current tile.
262278#pragma unroll
263279 for (int mi = 0 ; mi < Mma_tile_p::CORES_M; mi++) {
264- if (IS_FIRST_COL) {
265- local_max_[mi] = elt_[mi][0 ];
266- } else {
267- local_max_[mi] = fmaxf (global_max[mi], elt_[mi][0 ]);
268- }
280+ local_max_[mi] = elt_[mi][0 ];
269281#pragma unroll
270282 for (int ni = 1 ; ni < Mma_tile_p::CORES_N * 2 ; ni++) {
271283 local_max_[mi] = fmaxf (local_max_[mi], elt_[mi][ni]);
272284 }
273285 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 1 ), local_max_[mi]);
274286 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 2 ), local_max_[mi]);
287+
288+ if constexpr (may_skip) {
289+ // AND(&) the CORES_M results, then `skip` means whether to skip
290+ // the CORES_M(=2) rows
291+ if constexpr (!EXP2F_OPTIMIZATION) {
292+ skip &= expf (local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
293+ } else {
294+ skip &= exp2f ((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
295+ }
296+ }
297+
298+ if (!IS_FIRST_COL) {
299+ local_max_[mi] = fmaxf (local_max_[mi], global_max[mi]);
300+ }
301+ }
302+
303+ if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) {
304+ #ifdef SKIP_SOFTMAX_STAT
305+ total_blocks++;
306+ #endif
307+ if constexpr (may_skip) {
308+ // AND(&) the results together in a warp, then `skip` means whether to skip
309+ // all the 16 rows managed by this warp.
310+ // each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed
311+ // instead of 0xffffffff. But the perf is the same.
312+ skip = __all_sync (0xffffffff , skip);
313+ if (threadIdx.x % 32 == 0 ) {
314+ // The leader of each warp votes.
315+ atomicAnd (skip_softmax_vote, uint32_t (skip));
316+ }
317+ // WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
318+ named_barrier_wait (SKIP_SOFTMAX_BARRIER + threadIdx.x / 128 , 128 );
319+ skip = *((uint32_t volatile *)skip_softmax_vote);
320+ if (skip) {
321+ #ifdef SKIP_SOFTMAX_STAT
322+ skipped_blocks++;
323+ #endif
324+ return false ;
325+ }
326+ }
275327 }
276328
277329// Softmax Exp.
@@ -339,6 +391,7 @@ struct Softmax_base {
339391 global_max[mi] = max_new;
340392 }
341393 }
394+ return true ;
342395 }
343396
344397 // Update flash attention scales and pack elements for BMM2.
@@ -407,6 +460,13 @@ struct Softmax_base {
407460 float correction_[Mma_tile_p::CORES_M];
408461 // The packed mask.
409462 uint4 packed_mask_;
463+ // Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
464+ float skip_softmax_threshold;
465+ #ifdef SKIP_SOFTMAX_STAT
466+ // Statistics of skip-softmax
467+ uint32_t total_blocks;
468+ uint32_t skipped_blocks;
469+ #endif
410470};
411471
412472// //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -676,29 +736,72 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
676736 inline __device__ Softmax (Params const & params, int tidx) : Base(params, tidx) {}
677737
678738 // Calculate max/sum, and update flash-attention scales.
739+ // Returns false if skipped due to skip-softmax attention feature.
679740 template <bool IS_FIRST_COL>
680- inline __device__ void compute_and_update_scale (float (&global_max)[Mma_tile_p::CORES_M],
681- float (&global_sum)[Mma_tile_p::CORES_M]) {
741+ inline __device__ bool compute_and_update_scale (float (&global_max)[Mma_tile_p::CORES_M],
742+ float (&global_sum)[Mma_tile_p::CORES_M],
743+ uint32_t* skip_softmax_vote) {
682744 float const scale = reinterpret_cast <float const &>(this ->scale_bmm1_ );
683745 float (&local_max_)[Mma_tile_p::CORES_M] = this ->local_max_ ;
684746 float (&local_sum_)[Mma_tile_p::CORES_M] = this ->local_sum_ ;
685747 float (&correction_)[Mma_tile_p::CORES_M] = this ->correction_ ;
686748 float (&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2 ] = this ->elt_ ;
687749
750+ // whether this warpgroup skips the softmax
751+ constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
752+ bool skip = may_skip;
753+
688754// Row-wise max of current tile.
689755#pragma unroll
690756 for (int mi = 0 ; mi < Mma_tile_p::CORES_M; mi++) {
691- if (IS_FIRST_COL) {
692- local_max_[mi] = elt_[mi][0 ];
693- } else {
694- local_max_[mi] = fmaxf (global_max[mi], elt_[mi][0 ]);
695- }
757+ local_max_[mi] = elt_[mi][0 ];
696758#pragma unroll
697759 for (int ni = 1 ; ni < Mma_tile_p::CORES_N * 2 ; ni++) {
698760 local_max_[mi] = fmaxf (local_max_[mi], elt_[mi][ni]);
699761 }
700762 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 1 ), local_max_[mi]);
701763 local_max_[mi] = fmaxf (__shfl_xor_sync (uint32_t (-1 ), local_max_[mi], 2 ), local_max_[mi]);
764+ // AND(&) the CORES_M results, then `skip` means whether to skip
765+ // the CORES_M(=2) rows
766+ if constexpr (may_skip) {
767+ // AND(&) the CORES_M results, then `skip` means whether to skip
768+ // the CORES_M(=2) rows
769+ if constexpr (!EXP2F_OPTIMIZATION) {
770+ skip &= expf (local_max_[mi] - global_max[mi]) < this ->skip_softmax_threshold ;
771+ } else {
772+ skip &= exp2f ((local_max_[mi] - global_max[mi]) * scale) < this ->skip_softmax_threshold ;
773+ }
774+ }
775+ if (!IS_FIRST_COL) {
776+ local_max_[mi] = fmaxf (local_max_[mi], global_max[mi]);
777+ }
778+ }
779+
780+ if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) {
781+ #ifdef SKIP_SOFTMAX_STAT
782+ this ->total_blocks ++;
783+ #endif
784+
785+ if constexpr (may_skip) {
786+ // AND(&) the results together in a warp, then `skip` means whether to skip
787+ // all the 16 rows managed by this warp.
788+ // each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed
789+ // instead of 0xffffffff. But the perf is the same.
790+ skip = __all_sync (0xffffffff , skip);
791+ if (threadIdx.x % 32 == 0 ) {
792+ // The leader of each warp votes.
793+ atomicAnd (skip_softmax_vote, uint32_t (skip));
794+ }
795+ // WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
796+ named_barrier_wait (Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128 , 128 );
797+ skip = *((uint32_t volatile *)skip_softmax_vote);
798+ if (skip) {
799+ #ifdef SKIP_SOFTMAX_STAT
800+ this ->skipped_blocks ++;
801+ #endif
802+ return false ;
803+ }
804+ }
702805 }
703806
704807// Softmax Exp.
@@ -774,6 +877,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
774877 global_max[mi] = max_new;
775878 }
776879 }
880+ return true ;
777881 }
778882
779883 // Update flash attention scales and pack elements for BMM2.
0 commit comments