Skip to content

Commit 847af74

Browse files
authored
Merge pull request #1 from zhou-yuxin/add-skip-softmax
Add skip softmax
2 parents 2ea7730 + 0e5974f commit 847af74

10 files changed

Lines changed: 411 additions & 95 deletions

File tree

csrc/fmha_v2/fmha/warpspec/compute.h

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ struct Compute {
179179
USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
180180
: (q_step_idx * STEP_Q + head_info.q_tile_offset), \
181181
kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
182-
kv_step_idx == kv_idx_end - 1);
182+
&shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);
183183

184184
////////////////////////////////////////////////////////////////////////////////////////////////
185185

@@ -277,6 +277,12 @@ struct Compute {
277277
int const actual_kv_seqlen =
278278
SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;
279279

280+
// Update threshold of Skip-Softmax
281+
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) {
282+
softmax.skip_softmax_threshold =
283+
params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
284+
}
285+
280286
// Calculate the alibi head_scaling_factor.
281287
float alibi_head_scale = APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(
282288
head_info.bidh, params.alibi_params)
@@ -411,6 +417,12 @@ struct Compute {
411417
}
412418
}
413419
}
420+
#ifdef SKIP_SOFTMAX_STAT
421+
if (tidx == 0) {
422+
atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
423+
atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
424+
}
425+
#endif
414426
}
415427

416428
////////////////////////////////////////////////////////////////////////////////////////////////
@@ -421,7 +433,14 @@ struct Compute {
421433
float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M], int const tidx,
422434
int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
423435
int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr,
424-
Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, bool complete = false) {
436+
Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote,
437+
bool complete = false) {
438+
// Skip-softmax vote initialization
439+
if (tidx == 0) {
440+
// Note that we need a named_barrier_wait in compute_single_tile to make sure init is before
441+
// voting.
442+
*skip_softmax_vote = 1;
443+
}
425444
// load the scales of K/V from global memory
426445
#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
427446
if constexpr (block_size > 0) { \
@@ -453,6 +472,10 @@ struct Compute {
453472
// Ctile_p is only used once by each n step.
454473
ctile_p.clear();
455474

475+
// If skip_softmax is enabled, make sure there is no racing between the initialization and
476+
// writing of skip_softmax_vote.
477+
named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);
478+
456479
// BMM1 (Q x K').
457480
warpgroup_arrive();
458481

@@ -513,8 +536,19 @@ struct Compute {
513536
softmax.apply_alibi_and_mask<APPLY_MASK>(ctile_p, params.alibi_params, alibi_head_scale,
514537
actual_kv_seqlen, row_offset, col_offset);
515538

516-
// Softmax Exp, max/sum, and update scales.
517-
softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum);
539+
// Softmax Exp, max/sum, and update scales. If returns false we skip the rest.
540+
if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote)) {
541+
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) {
542+
// Notify another warpgroup to execute QGMMA.
543+
mutex.named_bar_arrive();
544+
}
545+
// Need to wait V, otherwise compute-sanitizer synccheck will fail.
546+
int ready2 = cbr_v.peek();
547+
if (!ready2) {
548+
cbr_v.wait();
549+
}
550+
return;
551+
}
518552

519553
// experiments show that here is the best place to load scales of V
520554
float scales_v[SAGE_BLOCKS_PER_STEP_V];

csrc/fmha_v2/fmha/warpspec/epilogue.h

Lines changed: 119 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include <fmha/traits.h>
1717
#include <fmha/utils.h>
1818

19+
#include "fmha/hopper/arrive_wait.h"
20+
1921
namespace fmha {
2022
namespace ws {
2123

@@ -71,6 +73,9 @@ struct Softmax_base {
7173
// Whether we need to check if local_max could be -inf or not.
7274
enum { CHECK_IF_NEG_INF_EXISTS = SLIDING_OR_CHUNKED_ATTENTION || USE_CUSTOM_MASK };
7375

76+
// There are 2 warpgroups, so named barrier IDs 0x3 and 0x4 are used
77+
enum { SKIP_SOFTMAX_BARRIER = Kernel_traits::SKIP_SOFTMAX_BARRIER_ID };
78+
7479
// Ctor.
7580
template <typename Params>
7681
inline __device__ Softmax_base(Params params, int tidx)
@@ -80,7 +85,12 @@ struct Softmax_base {
8085
sliding_window_size_(params.sliding_window_size),
8186
log2_chunked_attention_size_(params.log2_chunked_attention_size),
8287
packed_mask_ptr_{reinterpret_cast<uint32_t*>(params.packed_mask_ptr)},
83-
params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes} {
88+
params_packed_mask_stride_in_bytes_{params.packed_mask_stride_in_bytes},
89+
#ifdef SKIP_SOFTMAX_STAT
90+
total_blocks(0),
91+
skipped_blocks(0),
92+
#endif
93+
skip_softmax_threshold(0) {
8494
int warp = tidx / 32;
8595
int lane = tidx % 32;
8696
// The corresponding row/col for each thread after MMA.
@@ -253,25 +263,67 @@ struct Softmax_base {
253263
}
254264

255265
// Calculate max/sum, and update flash-attention scales.
266+
// Returns false if skipped due to skip-softmax attention feature.
256267
template <bool IS_FIRST_COL>
257-
inline __device__ void compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M],
258-
float (&global_sum)[Mma_tile_p::CORES_M]) {
268+
inline __device__ bool compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M],
269+
float (&global_sum)[Mma_tile_p::CORES_M],
270+
uint32_t* skip_softmax_vote) {
259271
float const scale = reinterpret_cast<float const&>(scale_bmm1_);
260272

273+
// whether this warpgroup skips the softmax
274+
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
275+
bool skip = may_skip;
276+
261277
// Row-wise max of current tile.
262278
#pragma unroll
263279
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++) {
264-
if (IS_FIRST_COL) {
265-
local_max_[mi] = elt_[mi][0];
266-
} else {
267-
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
268-
}
280+
local_max_[mi] = elt_[mi][0];
269281
#pragma unroll
270282
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++) {
271283
local_max_[mi] = fmaxf(local_max_[mi], elt_[mi][ni]);
272284
}
273285
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
274286
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
287+
288+
if constexpr (may_skip) {
289+
// AND(&) the CORES_M results, then `skip` means whether to skip
290+
// the CORES_M(=2) rows
291+
if constexpr (!EXP2F_OPTIMIZATION) {
292+
skip &= expf(local_max_[mi] - global_max[mi]) < skip_softmax_threshold;
293+
} else {
294+
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < skip_softmax_threshold;
295+
}
296+
}
297+
298+
if (!IS_FIRST_COL) {
299+
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
300+
}
301+
}
302+
303+
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) {
304+
#ifdef SKIP_SOFTMAX_STAT
305+
total_blocks++;
306+
#endif
307+
if constexpr (may_skip) {
308+
// AND(&) the results together in a warp, then `skip` means whether to skip
309+
// all the 16 rows managed by this warp.
310+
// each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed
311+
// instead of 0xffffffff. But the perf is the same.
312+
skip = __all_sync(0xffffffff, skip);
313+
if (threadIdx.x % 32 == 0) {
314+
// The leader of each warp votes.
315+
atomicAnd(skip_softmax_vote, uint32_t(skip));
316+
}
317+
// WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
318+
named_barrier_wait(SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
319+
skip = *((uint32_t volatile*)skip_softmax_vote);
320+
if (skip) {
321+
#ifdef SKIP_SOFTMAX_STAT
322+
skipped_blocks++;
323+
#endif
324+
return false;
325+
}
326+
}
275327
}
276328

277329
// Softmax Exp.
@@ -339,6 +391,7 @@ struct Softmax_base {
339391
global_max[mi] = max_new;
340392
}
341393
}
394+
return true;
342395
}
343396

344397
// Update flash attention scales and pack elements for BMM2.
@@ -407,6 +460,13 @@ struct Softmax_base {
407460
float correction_[Mma_tile_p::CORES_M];
408461
// The packed mask.
409462
uint4 packed_mask_;
463+
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold.
464+
float skip_softmax_threshold;
465+
#ifdef SKIP_SOFTMAX_STAT
466+
// Statistics of skip-softmax
467+
uint32_t total_blocks;
468+
uint32_t skipped_blocks;
469+
#endif
410470
};
411471

412472
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -676,29 +736,72 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
676736
inline __device__ Softmax(Params const& params, int tidx) : Base(params, tidx) {}
677737

678738
// Calculate max/sum, and update flash-attention scales.
739+
// Returns false if skipped due to skip-softmax attention feature.
679740
template <bool IS_FIRST_COL>
680-
inline __device__ void compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M],
681-
float (&global_sum)[Mma_tile_p::CORES_M]) {
741+
inline __device__ bool compute_and_update_scale(float (&global_max)[Mma_tile_p::CORES_M],
742+
float (&global_sum)[Mma_tile_p::CORES_M],
743+
uint32_t* skip_softmax_vote) {
682744
float const scale = reinterpret_cast<float const&>(this->scale_bmm1_);
683745
float(&local_max_)[Mma_tile_p::CORES_M] = this->local_max_;
684746
float(&local_sum_)[Mma_tile_p::CORES_M] = this->local_sum_;
685747
float(&correction_)[Mma_tile_p::CORES_M] = this->correction_;
686748
float(&elt_)[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2] = this->elt_;
687749

750+
// whether this warpgroup skips the softmax
751+
constexpr bool may_skip = Kernel_traits::ENABLE_SKIP_SOFTMAX && !IS_FIRST_COL;
752+
bool skip = may_skip;
753+
688754
// Row-wise max of current tile.
689755
#pragma unroll
690756
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++) {
691-
if (IS_FIRST_COL) {
692-
local_max_[mi] = elt_[mi][0];
693-
} else {
694-
local_max_[mi] = fmaxf(global_max[mi], elt_[mi][0]);
695-
}
757+
local_max_[mi] = elt_[mi][0];
696758
#pragma unroll
697759
for (int ni = 1; ni < Mma_tile_p::CORES_N * 2; ni++) {
698760
local_max_[mi] = fmaxf(local_max_[mi], elt_[mi][ni]);
699761
}
700762
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 1), local_max_[mi]);
701763
local_max_[mi] = fmaxf(__shfl_xor_sync(uint32_t(-1), local_max_[mi], 2), local_max_[mi]);
764+
// AND(&) the CORES_M results, then `skip` means whether to skip
765+
// the CORES_M(=2) rows
766+
if constexpr (may_skip) {
767+
// AND(&) the CORES_M results, then `skip` means whether to skip
768+
// the CORES_M(=2) rows
769+
if constexpr (!EXP2F_OPTIMIZATION) {
770+
skip &= expf(local_max_[mi] - global_max[mi]) < this->skip_softmax_threshold;
771+
} else {
772+
skip &= exp2f((local_max_[mi] - global_max[mi]) * scale) < this->skip_softmax_threshold;
773+
}
774+
}
775+
if (!IS_FIRST_COL) {
776+
local_max_[mi] = fmaxf(local_max_[mi], global_max[mi]);
777+
}
778+
}
779+
780+
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) {
781+
#ifdef SKIP_SOFTMAX_STAT
782+
this->total_blocks++;
783+
#endif
784+
785+
if constexpr (may_skip) {
786+
// AND(&) the results together in a warp, then `skip` means whether to skip
787+
// all the 16 rows managed by this warp.
788+
// each 4 threads (e.g. T0~T3) have the same `skip`, only 0x11111111 is needed
789+
// instead of 0xffffffff. But the perf is the same.
790+
skip = __all_sync(0xffffffff, skip);
791+
if (threadIdx.x % 32 == 0) {
792+
// The leader of each warp votes.
793+
atomicAnd(skip_softmax_vote, uint32_t(skip));
794+
}
795+
// WG0 uses 0x3 barrier, WG1 uses 0x4 barrier
796+
named_barrier_wait(Base::SKIP_SOFTMAX_BARRIER + threadIdx.x / 128, 128);
797+
skip = *((uint32_t volatile*)skip_softmax_vote);
798+
if (skip) {
799+
#ifdef SKIP_SOFTMAX_STAT
800+
this->skipped_blocks++;
801+
#endif
802+
return false;
803+
}
804+
}
702805
}
703806

704807
// Softmax Exp.
@@ -774,6 +877,7 @@ struct Softmax<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
774877
global_max[mi] = max_new;
775878
}
776879
}
880+
return true;
777881
}
778882

779883
// Update flash attention scales and pack elements for BMM2.

csrc/fmha_v2/fmha/warpspec/kernel_traits.h

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ template <
6565
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
6666
// Save softmax stats ?
6767
bool RETURN_SOFTMAX_STATS_ = false,
68+
// Enable skip softmax attention feature
69+
bool ENABLE_SKIP_SOFTMAX_ = false,
6870
// The output type (only used by fp8 kernels).
6971
typename OutputType = typename Instruction_traits<STEP_Q_, STEP_KV_, 0, false, false>::A_type,
7072
// The sage attention block size for Q, K and V
@@ -189,6 +191,9 @@ struct Kernel_traits {
189191
// Use the custom mask input ( attention_mask_type == 3.)
190192
enum { USE_CUSTOM_MASK = ATTENTION_MASK_TYPE_ == 3 };
191193

194+
// Are we enabling skip softmax attention feature?
195+
enum { ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_ };
196+
192197
static_assert(!USE_CUSTOM_MASK || STEP_KV == 64 || STEP_KV == 128 || STEP_KV == 256,
193198
"Not implemented!");
194199

@@ -250,6 +255,8 @@ struct Kernel_traits {
250255
// Named barrier ids
251256
static constexpr int DMA_SYNC_BARRIER_ID = 0x1;
252257
static constexpr int MMA_SYNC_BARRIER_ID = 0x2;
258+
// There are 2 warpgroups so 0x3 and 0x4 are used for skip-softmax
259+
static constexpr int SKIP_SOFTMAX_BARRIER_ID = 0x3;
253260

254261
// How many threads get involved in the dma group.
255262
enum { NUM_THREADS_IN_DMA_GROUP = DMA_GROUP_TRANSPOSE_V ? 128 : (PAGED_KV_INPUT ? 1 : 32) };
@@ -383,6 +390,11 @@ struct Kernel_traits {
383390
// Mutex
384391
OrderedMutex compute_mutex;
385392

393+
// 4 warps in a warpgroup vote to an atomic variable in shared memory
394+
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive
395+
// KV_STEPS.
396+
uint32_t skip_softmax_votes[2][NUM_COMPUTE_GROUPS];
397+
386398
inline __device__ void init(int tid0) {
387399
#pragma unroll
388400
for (int i = 0; i < NUM_COMPUTE_GROUPS; i++) {
@@ -439,24 +451,27 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
439451
bool ENABLE_BMM1_SOFTCAPPING_SCALE_ = false,
440452
// Save softmax stats ?
441453
bool RETURN_SOFTMAX_STATS_ = false,
454+
// Enable skip softmax attention feature
455+
bool ENABLE_SKIP_SOFTMAX_ = false,
442456
// The output type (only used by fp8 kernels).
443457
typename OutputType = e4m3_t,
444458
// The sage attention block size for Q, K and V
445459
int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0>
446460
struct Kernel_traits_Hopper_qgmma_e4m3_fp32
447-
: public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_,
448-
KV_BUFFERS_, NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_,
449-
ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
450-
SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_,
451-
ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_, OutputType,
452-
SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_> {
461+
: public Kernel_traits<
462+
Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
463+
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_,
464+
APPLY_ALIBI_, ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_,
465+
ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_, ENABLE_SKIP_SOFTMAX_, OutputType,
466+
SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_> {
453467
// Base class.
454-
using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_,
455-
KV_BUFFERS_, NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_,
456-
ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
457-
SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_,
458-
ENABLE_BMM1_SOFTCAPPING_SCALE_, RETURN_SOFTMAX_STATS_, OutputType,
459-
SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
468+
using Base =
469+
Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_,
470+
KV_BUFFERS_, NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_,
471+
HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_, SCHEDULING_MODE_,
472+
INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
473+
RETURN_SOFTMAX_STATS_, ENABLE_SKIP_SOFTMAX_, OutputType, SAGE_BLOCK_SIZE_Q_,
474+
SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;
460475

461476
enum { USE_TMA_STORE = USE_TMA_STORE_ };
462477

@@ -549,6 +564,11 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
549564
// Mutex
550565
OrderedMutex compute_mutex;
551566

567+
// 4 warps in a warpgroup vote to an atomic variable in shared memory
568+
// to decide whether to skip this STEP_KV. Double-buffered to avoid races between consecutive
569+
// STEP_KVs.
570+
uint32_t skip_softmax_votes[2][Base::NUM_COMPUTE_GROUPS];
571+
552572
inline __device__ void init(int tid0) {
553573
#pragma unroll
554574
for (int i = 0; i < Base::NUM_COMPUTE_GROUPS; i++) {

0 commit comments

Comments
 (0)