Skip to content

Commit 00d2e9f

Browse files
authored
[Backend] Fix more tt.scan layout issues (#9189)
Following #9185, I asked codex to find other issues with regression tests. It hacked around the issue, but this was enough for me to find the real issue and fix it properly. Great teamwork. We should generally audit other uses of `linearize`/`delinearize`, as the ones that use the legacy APIs will most likely be broken when used with broadcasted layouts.
1 parent 23e4085 commit 00d2e9f

4 files changed

Lines changed: 77 additions & 3 deletions

File tree

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,9 @@ Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
436436
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
437437
ArrayRef<unsigned> shape);
438438

439+
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
440+
triton::gpu::LinearEncodingAttr encoding, StringAttr dimName);
441+
439442
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
440443
ArrayRef<unsigned> order);
441444

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,9 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
537537
std::get<0>(getMultiDimLaneId(rewriter, helper, laneId));
538538
multiDimLaneId[helper.getAxis()] = b.i32_val(scanDim - 1);
539539
auto linearEncoding = helper.getEncoding();
540-
auto threadsPerWarp = linearEncoding.getThreadsPerWarp();
541-
auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp,
542-
helper.getOrder());
540+
auto kLane = StringAttr::get(rewriter.getContext(), "lane");
541+
Value laneIdLast =
542+
linearize(rewriter, loc, multiDimLaneId, linearEncoding, kLane);
543543
AddPartialReduceOneWarp(srcValues, rewriter, targetInfo, helper, warpIdAxis,
544544
laneIdAxis, laneIdLast);
545545
} // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do.

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,45 @@ Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) {
12291229
return result;
12301230
}
12311231

1232+
// Puts the bits of `a` that are set in `mask` into the bits of `result`
1233+
Value pdep_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) {
1234+
auto b = TritonLLVMOpBuilder(loc, rewriter);
1235+
assert(a.getType() == i32_ty && "a must be i32");
1236+
1237+
if (mask == 0)
1238+
return b.i32_val(0);
1239+
assert(mask < 64 && "mask must be less than 64");
1240+
1241+
// Blocked algorithm (same grouping trick as the pext example).
1242+
uint32_t mskConst = mask;
1243+
uint32_t depcnt = 0; // how many source bits from `a` we've consumed
1244+
Value result = b.i32_val(0);
1245+
1246+
while (mskConst) {
1247+
uint32_t oldmsk = mskConst;
1248+
1249+
// Isolate lsb set bit, then clear the lowest contiguous run of 1s.
1250+
uint32_t bitgrplsb = mskConst & (~mskConst + 1); // m & -m
1251+
mskConst &= (bitgrplsb + mskConst);
1252+
uint32_t bitgrp = mskConst ^ oldmsk; // the cleared run (contiguous 1s)
1253+
1254+
// Group start position and length.
1255+
uint32_t lsbpos = __builtin_ctz(bitgrplsb);
1256+
uint32_t grplen = __builtin_ctz(~(bitgrp >> lsbpos));
1257+
1258+
// Align the next grplen bits of `a` to the group's lsb, then mask to the
1259+
// group.
1260+
uint32_t shift =
1261+
lsbpos - depcnt; // non-negative invariant for this traversal order
1262+
depcnt += grplen;
1263+
1264+
Value deposited = b.and_(b.shl(a, b.i32_val(shift)), b.i32_val(bitgrp));
1265+
result = b.or_(result, deposited);
1266+
}
1267+
1268+
return result;
1269+
}
1270+
12321271
std::tuple<SmallVector<Value>, Value>
12331272
delinearize(RewriterBase &rewriter, Location loc,
12341273
triton::gpu::DistributedEncodingTrait layout,
@@ -1344,6 +1383,20 @@ Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
13441383
return linear;
13451384
}
13461385

1386+
// Linearizes `multiDim` indices for the input dimension `dimName` of a
// linear encoding. Unlike the plain shape/order overload, this handles
// broadcasted layouts: index bits of `dimName` that are free variables of
// the layout are skipped, and the packed index is deposited into the
// remaining (non-free) bit positions.
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
                triton::gpu::LinearEncodingAttr encoding, StringAttr dimName) {
  auto order = encoding.orderPerDim(dimName, encoding.getOrder());
  auto sizes = encoding.basesPerDim(dimName);
  auto result = linearize(rewriter, loc, multiDim, sizes, order);

  auto layout = encoding.getLinearLayout();
  int32_t freeBits = layout.getFreeVariableMasks().lookup(dimName);
  if (freeBits != 0) {
    // Scatter the packed index onto the non-free bits of the dimension.
    int32_t usedBits = ~freeBits & (layout.getInDimSize(dimName) - 1);
    result = pdep_i32(rewriter, loc, result, usedBits);
  }
  return result;
}
1399+
13471400
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
13481401
ArrayRef<unsigned> order) {
13491402
size_t linear = 0;

python/test/gluon/test_lowerings.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,24 @@ def test_scan_blocked_broadcast_layout(device):
111111
torch.testing.assert_close(y, torch.cumsum(x, dim=0))
112112

113113

114+
def test_scan_blocked_broadcast_layout_multiblock(device):
    # Regression test: cumulative sum over a blocked layout that broadcasts
    # along the lane dimension for dim 1 while splitting axis 0 across
    # multiple scan blocks.
    if not is_cuda():
        pytest.skip("requires CUDA")
    if THREADS_PER_WARP != 32:
        pytest.skip("requires 32-thread warps")

    num_rows = 64
    # Broadcasting in lane for dim1 and multiple scan blocks along axis 0.
    layout = ttgl.BlockedLayout([2, 4], [16, 2], [1, 2], [1, 0])

    torch.manual_seed(0)
    x = torch.randn((num_rows, 1), dtype=torch.float32, device=device)
    out = torch.empty_like(x)
    scan_kernel[(1, )](x, out, num_rows, 1, layout, 0, num_warps=2)

    expected = torch.cumsum(x, dim=0)
    torch.testing.assert_close(out, expected)
130+
131+
114132
def _reduce_linear_layouts():
115133
if THREADS_PER_WARP == 32:
116134
return [

0 commit comments

Comments
 (0)