Skip to content

Commit e9968a6

Browse files
committed
revert redundant clearing
1 parent 7e8ecb2 commit e9968a6

1 file changed

Lines changed: 3 additions & 16 deletions

File tree

include/flashinfer/topk.cuh

Lines changed: 3 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -502,23 +502,10 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory(
502 502
barrier_phase++;
503 503
__syncthreads();
504 504

505-
// CTA 0 clears output counter and first histogram AFTER barrier
506-
// Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
507-
// For iter>0, k>=vocab iterations clear the next histogram at their end
508-
// Per-round clearing handles subsequent rounds within the same iteration
509-
if (cta_in_group == 0) {
510-
if (iter == 0) {
511-
// First iteration: clear first round's histogram (buffer might be uninitialized)
512-
// Per-round clearing will handle histograms for rounds 1-3
513-
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
514-
state->histogram[0][i] = 0;
515-
}
516-
}
517-
if (tx == 0) {
518-
st_release(&state->output_counter, 0);
519-
}
505+
// CTA 0 clears output counter AFTER barrier
506+
if (cta_in_group == 0 && tx == 0) {
507+
st_release(&state->output_counter, 0);
520 508
}
521-
__syncthreads(); // Ensure histogram clearing completes before any CTA proceeds
522 509
}
523 510

524 511
// NUM_ROUNDS of radix select

0 commit comments

Comments
 (0)