Commit f937cc8

apakbin authored and jerrymannil committed
[ROCm] Reduce RadixSelect sync overhead by moving __syncthreads to findPatternDataSmem (pytorch#178188)
## Summary PR pytorch#177149 fixed a race condition introduced by pytorch#174837: after `countRadixAggregateCounts` Stage 3 reads counts from smem, warp 0 may get ahead of lagging warps still in Stage 3 and call `findPatternDataSmem`, overwriting `smem[0]`/`smem[1]` while lagging warps are still reading `smem[buffer_offset + i]` (which overlaps with `smem[0]`/`smem[1]` when `buffer_offset == 0`). The fix placed a `__syncthreads()` at the end of Stage 3, which runs on every iteration of the radix digit loop, negating part of the synchronization overhead that pytorch#174837 worked to eliminate. This patch moves that sync to the **beginning of `findPatternDataSmem`** instead. ## Why this is correct 1. All threads evaluate the same `counts[]` values and all reach `found_unique()` together, so `__syncthreads()` inside `findPatternDataSmem` is collectively reachable by all threads in the block. 2. By the time any thread enters `findPatternDataSmem`, every thread has already finished reading Stage 3 (they all had to evaluate the bucket loop to get here), so syncing before the `smem[0]`/`smem[1]` writes is sufficient to prevent the race. ## Performance `findPatternDataSmem` is called **at most once** per `radixSelect` invocation — only when `count == 1` (a unique element is identified), at which point the function returns immediately. The removed sync ran on every radix digit iteration (up to 16 times for float32). This saves up to 15 `__syncthreads()` calls in the common case. Pull Request resolved: pytorch#178188 Approved by: https://github.com/jeffdaily, https://github.com/jeanschmidt
1 parent 48d3e2d commit f937cc8
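The race and the fix can be illustrated with a CPU analogue (a minimal sketch, not the actual CUDA code: `threading.Barrier` stands in for `__syncthreads()`, a Python list for shared memory, and the thread/variable names are hypothetical). Placing the barrier at the start of the overwrite phase, as the patch does, guarantees every thread's Stage-3-style read completes before any thread clobbers the overlapping slots:

```python
# CPU sketch of the smem race fix (illustrative names; the real code is CUDA
# in SortingRadixSelect.cuh). threading.Barrier plays the role of
# __syncthreads(): no thread proceeds past it until all have arrived.
import threading

NUM_THREADS = 4
smem = [10, 20, 30, 40]           # shared buffer; slots 0/1 get overwritten
barrier = threading.Barrier(NUM_THREADS)
stage3_reads = [None] * NUM_THREADS

def worker(tid):
    # Stage 3 analogue: every thread reads the shared counts.
    stage3_reads[tid] = list(smem)
    # Sync placed at the *start* of the overwrite phase (the fix): thread 0
    # cannot clobber smem[0]/smem[1] until all reads above have finished.
    barrier.wait()
    if tid == 0:                  # "warp 0" analogue writes the flag/value
        smem[0] = -1
        smem[1] = -1

threads = [threading.Thread(target=worker, args=(t,)) for t in range(NUM_THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every thread observed the pre-overwrite values, even though smem[0]/smem[1]
# were later clobbered by thread 0.
assert all(r == [10, 20, 30, 40] for r in stage3_reads)
assert smem[:2] == [-1, -1]
```

Without the barrier, a fast thread 0 could overwrite `smem[0]`/`smem[1]` while slower threads were still copying, which is exactly the overlap the PR describes when `buffer_offset == 0`.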

File tree

1 file changed: +9 −1 lines


aten/src/ATen/native/cuda/SortingRadixSelect.cuh

Lines changed: 9 additions & 1 deletion
```diff
@@ -523,7 +523,6 @@ __device__ __forceinline__ void countRadixAggregateCounts(
   for (uint32_t i = 0; i < RadixSize; ++i) {
     counts[i] = smem[buffer_offset + i];
   }
-  __syncthreads(); // Wait for all threads to finish reading the final counts.
 }

 // This function counts the distribution of all input values in a
@@ -694,6 +693,15 @@ __device__ scalar_t findPatternDataSmem(
     const scalar_t* dataSmem, // input data stored in shared memory.
     index_t dataSmemSize) { // input data size stored in shared memory.

+  // Ensure all threads have finished reading from smem before overwriting it.
+  // countRadixAggregateCounts Stage 3 reads from smem[buffer_offset + i];
+  // when buffer_offset == 0, those locations overlap with smem[0]/smem[1]
+  // written below. Warp 0 (which writes smem[0]/smem[1]) may get ahead of
+  // lagging warps still in Stage 3. Syncing here (rather than at the end of
+  // Stage 3) is cheaper because findPatternDataSmem is called at most once per
+  // radixSelect invocation, only when a unique element is found (count == 1).
+  __syncthreads();
+
   // initialize smem to 0.
   // smem[0] is a flag to indicate if a value has been found.
   // smem[1] is the found value.
```
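The "up to 15 saved syncs" figure in the commit message can be checked with quick arithmetic (a sketch assuming 2-bit radix digits, inferred from the stated 16 iterations over a 32-bit float32 key; the actual `RADIX_BITS` constant is not shown in this diff):

```python
# Savings estimate for float32, assuming 2-bit radix digits (16 passes over
# a 32-bit key, matching the commit message; RADIX_BITS here is an assumption).
KEY_BITS = 32
RADIX_BITS = 2                     # RadixSize = 2**RADIX_BITS = 4 buckets
passes = KEY_BITS // RADIX_BITS    # iterations of the radix digit loop
assert passes == 16

# Old placement: one __syncthreads() at the end of Stage 3, every pass.
# New placement: one sync total, issued only when findPatternDataSmem runs.
saved_syncs = passes - 1
assert saved_syncs == 15
```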
