Skip to content

Commit 4e117a3

Browse files
denix56Denys Senkin
authored andcommitted
Implement scatter op with 'reduce' functionality support
1 parent 0bc402c commit 4e117a3

30 files changed

Lines changed: 2204 additions & 3 deletions

File tree

docs/python-api/triton.language.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ Scan/Sort Ops
152152
sort
153153
topk
154154
gather
155+
scatter
155156

156157
Atomic Ops
157158
----------

include/triton/Analysis/Utility.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,21 @@ class GatherLoweringHelper {
146146
RankedTensorType dstTy;
147147
};
148148

149+
// Helper class for lowering `tt.scatter` operations. This class shares lowering
150+
// logic between shared memory allocation and LLVM codegen.
151+
class ScatterLoweringHelper {
152+
public:
153+
ScatterLoweringHelper(triton::ScatterOp scatterOp);
154+
155+
// Get the shared memory scratch size required by this op.
156+
unsigned getScratchSizeInBytes();
157+
// Determine if the scatter can be performed completely within a warp.
158+
bool isWarpLocal();
159+
160+
private:
161+
triton::ScatterOp scatterOp;
162+
};
163+
149164
// This struct represents the factorization of a warp-local layout conversion
150165
// into three components: a register-only permutation, a lane-only permutation,
151166
// and a set of swaps between lane and register basis vectors. Algebraically, it

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
7676
RewritePatternSet &patterns,
7777
const TargetInfoBase &targetInfo,
7878
PatternBenefit benefit);
79+
void populateScatterOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
80+
RewritePatternSet &patterns,
81+
const TargetInfoBase &targetInfo,
82+
PatternBenefit benefit);
7983

8084
void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
8185
const TargetInfoBase &targetInfo,

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ class TargetInfoBase {
1414

1515
virtual Value ballot(RewriterBase &rewriter, Location loc, Type type,
1616
Value cmp) const = 0;
17+
// Returns the subset of lanes in `activeMask` that have the same `value` as
18+
// the current lane. Backends may override with native match-any support.
19+
virtual Value matchAny(RewriterBase &rewriter, Location loc, Type maskType,
20+
Value value, Value activeMask) const = 0;
1721

1822
// Emit a block/CTA level barrier that guarantees visibility for the
1923
// target address space

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,56 @@ def TT_GatherOp : TT_Op<"gather", [Pure,
960960
let hasVerifier = 1;
961961
}
962962

963+
//
964+
// Scatter Op
965+
//
966+
def TT_ScatterOp : TT_Op<"scatter", [Pure,
967+
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
968+
let summary = "local scatter operation";
969+
let description = [{
970+
Scatter elements from the source tensor into the destination tensor using
971+
the indices tensor along a single specified axis. The source and indices
972+
tensors must have the same shape. The output tensor has the same shape as
973+
the destination tensor.
974+
975+
For each source position I, this writes:
976+
out[I[0], ..., indices[I], ..., I[n]] = src[I]
977+
978+
If a reduction region is provided, then multiple source elements that map
979+
to the same destination are combined using the region. When `include_self`
980+
is true, the original destination value is included in the reduction for
981+
indices that are written. When it is false, only source values are
982+
combined, and destinations with no writes keep their original values.
983+
984+
The `efficient_layout` attribute is set when the compiler has determined an
985+
optimized layout for the operation, indicating that it should not be
986+
changed.
987+
}];
988+
989+
let arguments = (ins
990+
TT_Tensor:$dst,
991+
TT_IntTensor:$indices,
992+
TT_Tensor:$src,
993+
I32Attr:$axis,
994+
DefaultValuedAttr<BoolAttr, "true">:$include_self,
995+
OptionalAttr<StrAttr>:$reduce_kind,
996+
UnitAttr:$efficient_layout
997+
);
998+
let results = (outs TT_Tensor:$result);
999+
let regions = (region AnyRegion:$combineOp);
1000+
1001+
let hasVerifier = 1;
1002+
let hasRegionVerifier = 1;
1003+
let hasCustomAssemblyFormat = 1;
1004+
}
1005+
1006+
def TT_ScatterReturnOp: TT_Op<"scatter.return",
1007+
[HasParent<"ScatterOp">, Pure, Terminator, ReturnLike]> {
1008+
let summary = "terminator for scatter reduction operator";
1009+
let arguments = (ins Variadic<AnyType>:$result);
1010+
let assemblyFormat = "$result attr-dict `:` type($result)";
1011+
}
1012+
9631013
//
9641014
// Print Op
9651015
//

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality",
268268

269269
let description = [{
270270
The aim of this pass is to reduce cross-thread communication for certain
271-
operations, like reductions, reshapes, and gathers.
271+
operations, like reductions, reshapes, gathers, and scatters.
272272

273273
For reduction operations, this pass attempts to adjust the reduction size
274274
(or layout) to avoid splitting the reduction operation between multiple
@@ -281,6 +281,10 @@ def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality",
281281
heuristics to determine when it is appropriate to assign specific layouts
282282
and trigger their respective codegen paths. For now, the pass only attempts
283283
to apply layouts that result in warp-synchronous gathers.
284+
285+
For scatters, this pass applies the same strategy and attempts to assign
286+
layouts that keep source/index updates warp-synchronous with their
287+
destination columns.
284288
}];
285289

286290
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
8686
GatherLoweringHelper helper(gatherOp);
8787
return helper.getScratchSizeInBytes();
8888
}
89+
if (auto scatterOp = dyn_cast<ScatterOp>(op)) {
90+
ScatterLoweringHelper helper(scatterOp);
91+
return helper.getScratchSizeInBytes();
92+
}
8993
if (auto histogram = dyn_cast<HistogramOp>(op)) {
9094
auto dstTy = histogram.getType();
9195
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(

lib/Analysis/Utility.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,69 @@ bool GatherLoweringHelper::isWarpLocal() {
885885
idxLayout.sublayout(kLane, otherDims);
886886
}
887887

888+
ScatterLoweringHelper::ScatterLoweringHelper(triton::ScatterOp scatterOp)
889+
: scatterOp(scatterOp) {}
890+
891+
unsigned ScatterLoweringHelper::getScratchSizeInBytes() {
892+
// Otherwise, scattering will write into a temporary output tensor in shared
893+
// memory before materializing the final output registers.
894+
RankedTensorType dstType = scatterOp.getDst().getType();
895+
unsigned elemBytes = ceil<unsigned>(dstType.getElementTypeBitWidth(), 8);
896+
unsigned dstBytes = product(dstType.getShape()) * elemBytes;
897+
bool hasCombine =
898+
!scatterOp.getCombineOp().empty() || scatterOp.getReduceKindAttr();
899+
if (hasCombine) {
900+
// Extra i32 per element for CAS-based locks/flags.
901+
dstBytes += product(dstType.getShape()) * sizeof(int32_t);
902+
}
903+
return dstBytes;
904+
}
905+
906+
bool ScatterLoweringHelper::isWarpLocal() {
907+
// The scatter is warp-local if all source/index writes for any destination
908+
// column can be serviced within a single warp.
909+
RankedTensorType dstType = scatterOp.getDst().getType();
910+
RankedTensorType srcType = scatterOp.getSrc().getType();
911+
RankedTensorType idxType = scatterOp.getIndices().getType();
912+
LinearLayout dstLayout = toLinearLayout(dstType);
913+
LinearLayout srcLayout = toLinearLayout(srcType);
914+
LinearLayout idxLayout = toLinearLayout(idxType);
915+
916+
Builder b(scatterOp.getContext());
917+
StringAttr kBlock = b.getStringAttr("block");
918+
StringAttr kWarp = b.getStringAttr("warp");
919+
StringAttr kLane = b.getStringAttr("lane");
920+
StringAttr kScatterDim =
921+
b.getStringAttr("dim" + std::to_string(scatterOp.getAxis()));
922+
923+
// The scatter dimension must be invariant with respect to warp/block in all
924+
// participating tensors.
925+
if (!dstLayout.sublayoutIsZero({kBlock, kWarp}, kScatterDim) ||
926+
!srcLayout.sublayoutIsZero({kBlock, kWarp}, kScatterDim) ||
927+
!idxLayout.sublayoutIsZero({kBlock, kWarp}, kScatterDim))
928+
return false;
929+
930+
SmallVector<StringAttr> otherDims;
931+
for (unsigned dim = 0, rank = dstType.getRank(); dim < rank; ++dim) {
932+
if (dim != scatterOp.getAxis()) {
933+
otherDims.push_back(b.getStringAttr("dim" + Twine(dim)));
934+
}
935+
}
936+
937+
// Source/index and destination columns must line up identically across warps.
938+
if (dstLayout.sublayout({kBlock, kWarp}, otherDims) !=
939+
srcLayout.sublayout({kBlock, kWarp}, otherDims) ||
940+
dstLayout.sublayout({kBlock, kWarp}, otherDims) !=
941+
idxLayout.sublayout({kBlock, kWarp}, otherDims))
942+
return false;
943+
944+
// Require lane ownership of columns to match for simpler codegen.
945+
return dstLayout.sublayout(kLane, otherDims) ==
946+
srcLayout.sublayout(kLane, otherDims) &&
947+
dstLayout.sublayout(kLane, otherDims) ==
948+
idxLayout.sublayout(kLane, otherDims);
949+
}
950+
888951
unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
889952
if (shape.empty())
890953
return 0;

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_triton_library(TritonGPUToLLVM
1818
PrintOpToLLVM.cpp
1919
ReduceOpToLLVM.cpp
2020
ScanOpToLLVM.cpp
21+
ScatterOpToLLVM.cpp
2122
SPMDOpToLLVM.cpp
2223
TypeConverter.cpp
2324
Utility.cpp

0 commit comments

Comments
 (0)