Commit 40425d1

minsii authored and facebook-github-bot committed
Enable native PAT AVG (#511)
Summary: Implements native AVG support for the PAT (Parallel Aggregated Trees) algorithm in ReduceScatter.

## Problem

Baseline NCCL doesn't support the PAT + AVG combination:
- if `NCCL_ALGO=reducescatter:pat` is NOT set: NCCL falls back to the Ring algorithm
- if `NCCL_ALGO=reducescatter:pat` is set: the call fails with `ncclInvalidUsage - Error : no algorithm/protocol available for function`

## Solution

This diff enables native AVG with the PAT algorithm for ReduceScatter: it divides by the integer nRanks at the last step, when writing the final sum result into recvbuf.

**Documentation Added:**
- `meta/collectives/docs/ReduceScatterPat.md` - comprehensive PAT algorithm documentation, including a 5-phase breakdown and an 8-rank visualization
- `meta/collectives/docs/ReduceScatterPatAvg.md` - PAT AVG design details, multi-chunk handling, and implementation notes

**Key Implementation:**
- Add an `isFinalWrite` flag to the `ncclPatStep` struct (set in Phase 4) to correctly apply division for all chunks in multi-chunk transfers (fixes a large-message bug)
- Add a `FuncPatAvg<T>` template that uses `FuncSum` for reduction and applies division as a postOp in the final write step
- Add an `ncclDevPatAvg` enum for kernel dispatch
- Update `generate.py` and `def_build.bzl` for PatAvg kernel generation
- Enable via `NCCL_ALGO=reducescatter:pat_postdiv`

**Meta overlay pattern used to minimize upstream changes:**
- `meta/device/FuncPatAvg.cuh`: full implementation (~120 lines)
- `meta/collectives/PatAvgAlgoHelper.h`: helper functions with lazy env detection
- All `src/` changes (~15 lines) are marked with `[META:PAT_AVG]` comments for rebase tracking

Differential Revision: D91948601
1 parent b7f7f29 commit 40425d1

File tree: 14 files changed (+736, -26 lines)
`meta/collectives/PatAvgAlgoHelper.h` (new file: 43 additions, 0 deletions)
```cpp
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

#include "comms/utils/cvars/nccl_cvars.h"
#include "device.h"
#include "info.h"

namespace ncclx {

// Check if PAT AVG mode is enabled via dedicated CVAR.
inline bool isPatAvgEnabled() {
  return NCCL_REDUCESCATTER_PAT_AVG_ENABLE;
}

// Check if PAT algorithm should be skipped for a given reduction operation.
// PAT doesn't support PreMulSum or SumPostDiv natively, but when PAT AVG
// is enabled, both are converted to ncclDevPatAvg and handled by PAT.
inline bool shouldSkipPatForReduceOp(ncclDevRedOp_t op) {
  if (isPatAvgEnabled()) {
    // When PAT AVG is enabled, PAT supports all AVG operations
    // (both PreMulSum and SumPostDiv will be converted to PatAvg)
    return false;
  }
  // Without PAT AVG, skip PAT for all AVG-related ops
  return op == ncclDevPreMulSum || op == ncclDevSumPostDiv;
}

// Switch opDev to ncclDevPatAvg when native PAT AVG is enabled.
// This should be called after algorithm selection in topoGetAlgoInfo().
// nRanks is needed to set scalarArg correctly for FuncPatAvg.
inline void maybeEnablePatAvg(struct ncclTaskColl* info, int nRanks) {
  if (info->algorithm == NCCL_ALGO_PAT && info->func == ncclFuncReduceScatter &&
      (info->opDev.op == ncclDevSumPostDiv ||
       info->opDev.op == ncclDevPreMulSum) &&
      isPatAvgEnabled()) {
    info->opDev.op = ncclDevPatAvg;
    // FuncPatAvg expects opArg = nRanks (just the integer count)
    info->opDev.scalarArg = static_cast<uint64_t>(nRanks);
  }
}

} // namespace ncclx
```
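
For reference, the single host-side call site for these helpers, quoted from the design notes below (which excerpt `src/enqueue.cc` after algorithm selection):

```cpp
// src/enqueue.cc, after algorithm selection (see ReduceScatterPatAvg.md below):
ncclx::maybeEnablePatAvg(info, comm->nRanks);
```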
`meta/collectives/docs/ReduceScatterPat.md` (new file: 189 additions, 0 deletions)
# PAT ReduceScatter Algorithm

## Overview

PAT (Parallel Aggregated Trees) ReduceScatter implements a recursive-halving algorithm that efficiently reduces and scatters data across ranks. For N ranks, each portion completes in log2(N) steps.

## Key Data Structures

### ncclPatStep (collectives.h)

```cpp
struct ncclPatStep {
  int recvDim, sendDim;       // Dimension for recv/send (-1 = none/local)
  int recvOffset, sendOffset; // Offset within peer buffer
  int stepOffset;             // Step offset for pipelining
  int postRecv, postSend;     // Post flags for completion
  int nelem, last, flags;     // Element count, completion, status
  bool isFinalWrite;          // True if final write for a chunk (apply division for AVG)
  size_t inpIx, outIx;        // Input/output buffer indices
};
```

### PatRSAlgorithm Class Members

```cpp
offset, end   // Current chunk range being processed
count         // Total elements per rank in input buffer
chunkCount    // Elements per chunk iteration
nelem         // Actual elements this iteration
rank, nranks  // This rank's ID and total rank count
nrPow2        // Next power of 2 >= nranks
aggFactor     // Aggregation factor (batches multiple steps)
aggDelta      // Step delta = nrPow2 / aggFactor
as            // Current aggregated step index
a             // Sub-step within aggregated step
phase         // Current algorithm phase (0-4)
scale         // Scaling factor for phases 2-3
```
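
For concreteness, instantiating a few of these members for the 8-rank walkthrough used later in this document (the aggFactor value is hypothetical, chosen only for illustration):

```
nranks = 8      -> nrPow2 = 8 (already a power of 2)
aggFactor = 2   -> aggDelta = nrPow2 / aggFactor = 4
```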

## recvDim and sendDim Encoding

```
recvDim = -1  -> No receive (use local data only as source)
recvDim >= 0  -> Receive from peer along hypercube dimension N

sendDim = -1  -> Write to LOCAL output buffer (userOutput + outIx)
sendDim >= 0  -> Send to peer along hypercube dimension N
```

The dimension corresponds to hypercube edges; the peer along a dimension is the rank with that bit flipped, as sketched after this list. For 8 ranks:
- Dim 0: pairs (0,1), (2,3), (4,5), (6,7) - rank XOR 1
- Dim 1: pairs (0,2), (1,3), (4,6), (5,7) - rank XOR 2
- Dim 2: pairs (0,4), (1,5), (2,6), (3,7) - rank XOR 4
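
A minimal sketch of that peer computation (the helper name `peerAlongDim` is illustrative, not from the source):

```cpp
// Peer along hypercube dimension `dim` is the rank with bit `dim` flipped.
// E.g., with 8 ranks: rank 3 (0b011) along dim 2 pairs with rank 7 (0b111).
inline int peerAlongDim(int rank, int dim) {
  return rank ^ (1 << dim);
}
```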

## The 5 Phases

The algorithm uses 5 phases organized into two groups:

### Primary Reduction (Phases 0-1)

Handles the main recursive halving, processing odd-indexed `as` values.

| Phase | Description | recvDim | sendDim |
|-------|-------------|---------|---------|
| **0** | Initial scatter: D2D copy from input, send to dim 0 peer | -1 (none) | 0 |
| **1** | Recursive halving: receive from dimension, reduce, forward to next dimension | `firstBitSet(s)` | `firstBitSet(s')` or -1 |

### Secondary Reduction (Phases 2-3)

Activated when `aggFactor > 1`. Forms a butterfly pattern at increasing scales to complete the reduction.

| Phase | Description | recvDim | sendDim |
|-------|-------------|---------|---------|
| **2** | Receive from dim 0 peer, reduce, forward to higher dimension | 0 | `firstBitSet(s)` or -1 |
| **3** | Receive from higher dimension, reduce, forward or write locally | `firstBitSet(s)` | `firstBitSet(s')` or -1 |

Phases 2-3 loop with `scale` doubling each iteration until `scale >= aggFactor`.
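
The doubling schedule, shown as a standalone sketch (this is not the actual `PatRSAlgorithm` state machine, just the shape of the loop it implements):

```cpp
#include <cstdio>

int main() {
  int aggFactor = 8; // hypothetical aggregation factor
  // One phase-2 pass and one phase-3 pass per scale, scale doubling each time.
  for (int scale = 1; scale < aggFactor; scale *= 2) {
    std::printf("phases 2-3 at scale %d\n", scale);
  }
  return 0;
}
```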

### Finalization (Phase 4)

| Phase | Description | recvDim | sendDim |
|-------|-------------|---------|---------|
| **4** | Final receive from dim 0, reduce, write to output buffer | 0 | -1 (local) |

## 8-Rank ReduceScatter Example

### Setup

```
Ranks: 0, 1, 2, 3, 4, 5, 6, 7
nrPow2 = 8
Dimensions: 0, 1, 2 (log2(8) = 3 dimensions)

Input:  Each rank has input[0..7] (8 portions)
Output: Rank r gets reduced sum of all ranks' input[r]
```
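
Written out as a formula (standard ReduceScatter semantics; `count` is the per-portion element count from the class members above, and the division applies only in the AVG case):

```
sum:  output_r[i] = input_0[r*count + i] + input_1[r*count + i] + ... + input_7[r*count + i]
AVG:  output_r[i] = (that sum) / 8
```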

### 3-Step Recursive Halving for portion[0] -> R0

```
STEP 1: Dim 0 exchange (pairs: 0<->1, 2<->3, 4<->5, 6<->7)
================================================================================

R0        R1        R2        R3        R4        R5        R6        R7
[0_0]     [0_1]     [0_2]     [0_3]     [0_4]     [0_5]     [0_6]     [0_7]
  |         |         |         |         |         |         |         |
  +----+----+         +----+----+         +----+----+         +----+----+
       |                   |                   |                   |
       v                   v                   v                   v
  [S0_{0,1}]          [S0_{2,3}]          [S0_{4,5}]          [S0_{6,7}]
    at R0               at R2               at R4               at R6


STEP 2: Dim 1 exchange (pairs: 0<->2, 1<->3, 4<->6, 5<->7)
================================================================================

    R0                  R2                  R4                  R6
[S0_{0,1}]          [S0_{2,3}]          [S0_{4,5}]          [S0_{6,7}]
    |                   |                   |                   |
    +---------+---------+                   +---------+---------+
              |                                       |
              v                                       v
       [S0_{0,1,2,3}]                          [S0_{4,5,6,7}]
           at R0                                   at R4


STEP 3: Dim 2 exchange (pairs: 0<->4, 1<->5, 2<->6, 3<->7)
================================================================================

      R0                                    R4
[S0_{0,1,2,3}]                        [S0_{4,5,6,7}]
      |                                     |
      +------------------+------------------+
                         |
                         v
                [S0_{all 8 ranks}]
                       at R0
                         |
                      /8 (AVG)
                         |
                         v
                  R0.OUTPUT = AVG
```

### All 8 Portions in Parallel (same 3 steps)

```
STEP 1 (Dim 0): Each dim0 pair reduces
--------------------------------------------------------------------------------
portion[0]: R0,R1 -> R0        portion[1]: R0,R1 -> R1
portion[2]: R2,R3 -> R2        portion[3]: R2,R3 -> R3
portion[4]: R4,R5 -> R4        portion[5]: R4,R5 -> R5
portion[6]: R6,R7 -> R6        portion[7]: R6,R7 -> R7

STEP 2 (Dim 1): Each dim1 pair reduces
--------------------------------------------------------------------------------
portion[0]: R0,R2 -> R0        portion[1]: R1,R3 -> R1
portion[2]: R0,R2 -> R2        portion[3]: R1,R3 -> R3
portion[4]: R4,R6 -> R4        portion[5]: R5,R7 -> R5
portion[6]: R4,R6 -> R6        portion[7]: R5,R7 -> R7

STEP 3 (Dim 2): Each dim2 pair reduces, FINAL destination reached
--------------------------------------------------------------------------------
portion[0]: R0,R4 -> R0        portion[1]: R1,R5 -> R1
portion[2]: R2,R6 -> R2        portion[3]: R3,R7 -> R3
portion[4]: R0,R4 -> R4        portion[5]: R1,R5 -> R5
portion[6]: R2,R6 -> R6        portion[7]: R3,R7 -> R7

All portions complete in 3 steps! Apply /8 for AVG.
```

## Buffer Operations in Device Code (prims_simple.h)

```cpp
// Sources setup:
if (recv) {
  srcs[0] = peer->buff + recvOffset;  // Received data from peer, stored in tmp buffer
}
if (send && sendDim >= 0) {
  dsts[0] = peer->buff + sendOffset;  // Send tmp buffer
  srcs[1] = userInput + inpIx;        // Local contribution
}
if (sendDim < 0) {                    // Local write (phase 4 or intermediate)
  dsts[0] = userOutput + outIx;       // Output buffer
  srcs[1] = userInput + inpIx;        // Local contribution
}

// Reduce: srcs[0] (received) + srcs[1] (local) -> dsts[0]
reduceCopy(..., nSrcs, srcs, 1, dsts, ...);
```
`meta/collectives/docs/ReduceScatterPatAvg.md` (new file: 128 additions, 0 deletions)
# PAT AVG Design

For the PAT algorithm details (phases, data flow, buffer addressing), see [ReduceScatterPat.md](ReduceScatterPat.md).

## Original AVG Limitation

The original NCCL AVG implementation (`FuncSumPostDiv`) has a critical limitation:

**Only supports unsigned integer types** (uint8, uint32, uint64)

This is because `FuncSumPostDiv` uses integer division, which:
- Truncates results for signed integers (incorrect for negative values)
- Cannot represent fractional results for floating-point types
- Is fundamentally incompatible with float, half, bfloat16, and fp8 types
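
A one-line illustration of the signed-truncation point (plain C++ semantics, not NCCL code):

```cpp
#include <cassert>

int main() {
  int sum = (-3) + (-4); // -7
  int avg = sum / 2;     // integer division truncates toward zero: -3; true mean is -3.5
  assert(avg == -3);
  return 0;
}
```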
## Overview

FuncPatAvg provides native average (division) support for the PAT (Parallel Aggregated Trees) algorithm. Unlike `FuncSumPostDiv`, which only supports unsigned integers, `FuncPatAvg` supports all data types, including float, half, bfloat16, and fp8.

## Key Design: Two-Phase Approach

Reduction is a pure sum; division is applied as a postOp on the final write only.
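
A tiny arithmetic check of the divide-exactly-once rule (4 ranks contributing 1.0, 2.0, 3.0, 4.0; true AVG = 2.5):

```
sum-then-divide (FuncPatAvg):          ((1.0 + 2.0) + (3.0 + 4.0)) / 4       = 2.5      correct
divide at an intermediate write too:   (((1.0 + 2.0) / 4) + (3.0 + 4.0)) / 4 = 1.9375   wrong (partial sum divided twice)
```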
## 1. Apply_Reduce / Apply_PostOp Traits

The kernel uses trait classes to customize behavior per reduction function:

```cpp
// Kernel calls generic helpers (common_kernel.h):
acc = applyReduce(redFn, acc, tmp); // -> Apply_Reduce<Fn>::reduce()
acc = applyPostOp(redFn, acc);      // -> Apply_PostOp<Fn>::postOp()

// FuncPatAvg specializations (meta/device/FuncPatAvg.cuh):
Apply_Reduce<FuncPatAvg<T>>  -> delegates to FuncSum (pure addition)
Apply_PostOp<FuncPatAvg<T>>  -> fn.divide(x) (divide by nRanks)
```
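
To make the split concrete, here is a standalone, compilable illustration of the same two-trait pattern (simplified types, not NCCL's actual trait signatures; `FuncSumLike` and `FuncPatAvgLike` are names invented for this sketch):

```cpp
#include <cstdint>

// Plain sum op: the reduction used during all PAT steps.
template <typename T>
struct FuncSumLike {
  static T reduce(T a, T b) { return a + b; }
};

// FuncPatAvg-style op: reduce() delegates to the sum, postOp() divides once.
template <typename T>
struct FuncPatAvgLike {
  uint64_t nRanks; // host side stores nRanks in opDev.scalarArg

  static T reduce(T a, T b) { return FuncSumLike<T>::reduce(a, b); } // pure sum
  T postOp(T x) const { return x / static_cast<T>(nRanks); }         // final divide
};

// Usage shape: postOp runs only on the final write for a chunk,
// e.g. out = op.postOp(accumulatedSum);  // applied iff ps->isFinalWrite
```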

## 2. Host-Side Dispatch

```cpp
// enqueue.cc:1873 - after algorithm selection
ncclx::maybeEnablePatAvg(info, comm->nRanks);

// PatAvgAlgoHelper.h
void maybeEnablePatAvg(ncclTaskColl* info, int nRanks) {
  if (info->algorithm == NCCL_ALGO_PAT &&       // PAT selected
      info->func == ncclFuncReduceScatter &&    // ReduceScatter
      (info->opDev.op == ncclDevSumPostDiv ||   // AVG operation
       info->opDev.op == ncclDevPreMulSum) &&
      isPatAvgEnabled()) {                      // NCCL_REDUCESCATTER_PAT_AVG_ENABLE=1
    info->opDev.op = ncclDevPatAvg;             // Switch to PatAvg
    info->opDev.scalarArg = nRanks;             // Pass nRanks for division
  }
}
```

## 3. Kernel-Side Division at Final Write Step

Large messages are split into multiple chunks, each processed independently through the PAT algorithm. The key challenge is: **when should division be applied?**

- Division must happen exactly once per output element
- Each chunk goes through multiple phases (0-4)
- Multiple local writes occur during processing (Phase 1 intermediate writes)
- Only the final write for each chunk should trigger division

### Solution: isFinalWrite Flag

Added an explicit `isFinalWrite` flag to the `ncclPatStep` struct:

```cpp
struct ncclPatStep {
  // ... other fields ...
  bool isFinalWrite; // True if final write for a chunk
};
```

The flag is set only in Phase 4, which is the final write phase for each chunk. See the phase explanation in [ReduceScatterPat.md](ReduceScatterPat.md).

```cpp
// PatRSAlgorithm::getNextOp(), Phase 4:
} else if (phase == 4) {
  ps->recvDim = 0;
  ps->sendDim = -1;
  ps->isFinalWrite = true; // Division applied here
  offset += chunkCount;    // Move to next chunk
}
```

PostOp application uses this flag directly:

```cpp
// prims_simple.h patReduce():
const int applyPostOp = ps->isFinalWrite;
```

### Write Types During PAT Execution

| Write Type | Phase | sendDim | isFinalWrite | PostOp Applied |
|------------|-------|---------|--------------|----------------|
| Send to peer | 0-3 | >= 0 | false | No |
| Intermediate local write | 1 | -1 | false | No (partial sum) |
| Final chunk write | 4 | -1 | true | Yes (divide by nRanks) |

Phase 4 is the ONLY phase where all contributions have been accumulated, making it the correct place to apply division.

## Key Files

| File | Purpose |
|------|---------|
| `meta/device/FuncPatAvg.cuh` | FuncPatAvg definition, Apply_Reduce/Apply_PostOp traits |
| `meta/collectives/PatAvgAlgoHelper.h` | Host-side dispatch logic |
| `src/include/collectives.h` | PatRSAlgorithm with isFinalWrite flag in ncclPatStep |
| `src/device/prims_simple.h` | patReduce() applies postOp based on isFinalWrite |
| `src/device/reduce_kernel.h` | Base trait definitions, applyReduce/applyPostOp helpers |
| `src/device/common_kernel.h` | reduceCopyPacks kernel that calls the traits |
| `src/enqueue.cc` | Calls maybeEnablePatAvg after algorithm selection |

## Enabling PAT AVG

Set environment variables:

```bash
export NCCL_ALGO="reducescatter:pat"
export NCCL_REDUCESCATTER_PAT_AVG_ENABLE=1
```

The `isPatAvgEnabled()` function reads the `NCCL_REDUCESCATTER_PAT_AVG_ENABLE` CVAR to determine whether native PAT AVG support should be used. This is a clean, dedicated control that works properly with the standard NCCL_ALGO algorithm parser.
