Commit f937cc8
[ROCm] Reduce RadixSelect sync overhead by moving __syncthreads to findPatternDataSmem (pytorch#178188)
## Summary
PR pytorch#177149 fixed a race condition introduced by pytorch#174837. After Stage 3 of `countRadixAggregateCounts` reads the counts from smem, warp 0 can run ahead of warps still in Stage 3 and call `findPatternDataSmem`, which overwrites `smem[0]`/`smem[1]` while the lagging warps are still reading `smem[buffer_offset + i]` (these locations overlap `smem[0]`/`smem[1]` when `buffer_offset == 0`). That fix placed a `__syncthreads()` at the end of Stage 3, which runs on every iteration of the radix digit loop and so reintroduces part of the synchronization overhead that pytorch#174837 worked to eliminate.
This patch moves that sync to the **beginning of `findPatternDataSmem`** instead.
## Why this is correct
1. All threads evaluate the same `counts[]` values and all reach `found_unique()` together, so `__syncthreads()` inside `findPatternDataSmem` is collectively reachable by all threads in the block.
2. A thread can enter `findPatternDataSmem` only after finishing its own Stage 3 reads (it had to evaluate the bucket loop to get here), so once every thread has arrived at the barrier at function entry, all Stage 3 reads are complete. Syncing before the `smem[0]`/`smem[1]` writes is therefore sufficient to prevent the race.
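The ordering argument can be illustrated with a small host-side model. This is only a sequential sketch of two interleavings, not the kernel code: `stage3Read`, `findPatternWrite`, `laggingWarpAgrees`, the 4-bucket size, and the sample counts are all hypothetical names and values chosen for illustration, and the "barrier" is modeled simply by forcing the lagging warp's reads to happen before the write.

```cpp
#include <array>
#include <cassert>

// Assumed bucket count for this model only.
constexpr int kRadixSize = 4;
using Smem = std::array<int, kRadixSize>;

// Models a warp's Stage 3 reads of smem[buffer_offset + i] with buffer_offset == 0.
Smem stage3Read(const Smem& smem) { return smem; }

// Models the smem[0]/smem[1] writes done inside findPatternDataSmem.
void findPatternWrite(Smem& smem, int value) {
  smem[0] = 1;      // "found" flag
  smem[1] = value;  // found value
}

// Returns true if the lagging warp observes the same counts as the fast warp.
bool laggingWarpAgrees(bool barrierAtEntry) {
  Smem smem = {0, 1, 0, 3};            // sample aggregated counts
  const Smem fast = stage3Read(smem);  // warp 0 finishes Stage 3 first
  Smem slow{};
  if (barrierAtEntry) {
    // Warp 0 waits at the barrier until warp 1 arrives, which can only
    // happen after warp 1 has completed its own Stage 3 reads.
    slow = stage3Read(smem);
    findPatternWrite(smem, 42);
  } else {
    // No barrier: warp 0 writes while warp 1 has not yet read.
    findPatternWrite(smem, 42);
    slow = stage3Read(smem);
  }
  return slow == fast;
}

// Small self-check of the two interleavings.
void runModelChecks() {
  assert(laggingWarpAgrees(true));
  assert(!laggingWarpAgrees(false));
}
```

In this model the lagging warp sees clobbered counts only in the no-barrier interleaving, which matches the race described in the summary.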
## Performance
`findPatternDataSmem` is called **at most once** per `radixSelect` invocation — only when `count == 1` (a unique element is identified), at which point the function returns immediately. The removed sync ran on every radix digit iteration (up to 16 times for float32). This saves up to 15 `__syncthreads()` calls in the common case.
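The arithmetic behind the "up to 15" figure can be written out as a small model. It counts only the barrier attributable to this race fix and ignores any other synchronization in the loop. The function names are hypothetical, and the 2-bit radix digit is an assumption chosen to be consistent with the 16 passes stated above for float32.

```cpp
#include <cassert>

constexpr int kKeyBits = 32;   // float32 key width
constexpr int kRadixBits = 2;  // assumed digit width (gives 16 passes)
constexpr int kMaxPasses = kKeyBits / kRadixBits;

static_assert(kMaxPasses == 16, "model assumes 16 digit passes for float32");

// Old placement: one barrier at the end of Stage 3 on every digit pass.
constexpr int syncsWithStage3Barrier(int passesRun) { return passesRun; }

// New placement: one barrier only if findPatternDataSmem is actually called.
constexpr int syncsWithEntryBarrier(bool uniqueFound) {
  return uniqueFound ? 1 : 0;
}

// Barriers saved by moving the sync.
constexpr int syncsSaved(int passesRun, bool uniqueFound) {
  return syncsWithStage3Barrier(passesRun) - syncsWithEntryBarrier(uniqueFound);
}
```

Under these assumptions, `syncsSaved(kMaxPasses, true)` gives the 15 barriers quoted above for the case where the unique element is identified on the final pass.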
Pull Request resolved: pytorch#178188
Approved by: https://github.com/jeffdaily, https://github.com/jeanschmidt
1 parent 48d3e2d, commit f937cc8
1 file changed, with 9 additions and 1 deletion: original line 526 is removed and lines 696-704 (new numbering) are added.