Skip to content

Commit 0bc402c

Browse files
authored
[ConSan] First pass at improving ConSan compile times (#9366)
When compiling `01-attention-forward.py` with consan enabled, before: ``` # times.ir_initialization=790100 0.79 sec # stage='ttgir' duration=182893 0.18 sec # stage='llir' duration=99643940 99.6 sec # stage='llvmir' duration=27658211 27.7 sec # stage='ptx' duration=11120513 11.1 sec # stage='cubin' duration=1149355658 19.15 min ``` After: ``` # times.ir_initialization=796533 0.79 sec # stage='ttgir' duration=184720 0.18 sec # stage='llir' duration=16957192 17.0 sec # stage='llvmir' duration=3735579 3.74 sec # stage='ptx' duration=1972309 1.97 sec # stage='cubin' duration=34357675 34.36 sec ``` This PR does quite a number of things at once: * Custom CanonicalizeLLVMIR pass that adds a pattern for `select %false|%true, %a, %b` since LLVM dialect is missing this (and is opposed to adding it) * Cache global constants to avoid creating many copies of the same string when lowering asserts * Fix warp specialize lowering to handle function calls and deduplicate barrier lowering code between NVIDIA and AMD backends. To support function calls, non-kernel functions are rewritten to accept a barrier handle argument that is passed down from the call site * Rewrite `createMultiColumnMask` to generate a constant tensor rather than computing it from a bunch of `make_range` and masking. This single function was generating gigabytes of IR * Pick warp-local layouts in consan instrumentation. Previously, consan used thread-local layouts where every thread has a copy of the tensor. This was to avoid using shared memory. We can switch to warp-local layouts where each warp has a copy of the tensor distributed across its threads to reduce the generated IR (and register usage) by a factor 32, plus some extra IR needed for shuffles. * To support warp-local layouts, I added a two flags: `uniform` to `tt.assert` and replaced `tti.experiment_assert_in_thread` with a `tt.reduce` + `tt.assert uniform`. 
Uniform just means that only the first thread in the warp group will trigger the assert, since the condition is uniform. * I added an `always_use_warp_shuffle` function-level flag to force `convert_layout` lowering to use warp shuffles even when the performance heuristic picks shared memory, so that the layout conversions inside consan helpers avoid shared memory * Changed the lowering of `arith.constant` with a non-splat dense elements attribute to generate a constant global array from which each thread loads * Moved generation of global stores in the main function into a helper function to reduce bloat by deduplicating. This also enables separate compilation later.
1 parent 4804627 commit 0bc402c

32 files changed

Lines changed: 1079 additions & 753 deletions

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8686
mlir::triton::gpu::registerAllocateSharedMemoryPass();
8787
mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
8888
mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass();
89+
mlir::triton::gpu::registerCanonicalizeLLVMIR();
8990
mlir::triton::registerConvertWarpSpecializeToLLVM();
9091
mlir::triton::registerConvertTritonGPUToLLVMPass();
9192
mlir::triton::registerConvertNVGPUToLLVMPass();

include/triton/Conversion/TritonGPUToLLVM/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,8 @@ def TritonGPUAllocateWarpGroups : Pass<"tritongpu-allocate-warp-groups", "mlir::
4242
}];
4343
}
4444

45+
def CanonicalizeLLVMIR : Pass<"canonicalize-llvm-ir", "mlir::LLVM::LLVMFuncOp"> {
46+
let summary = "Canonicalize LLVM IR";
47+
}
48+
4549
#endif

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,7 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
103103
PatternBenefit benefit);
104104

105105
void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter,
106-
const TargetInfoBase &targetInfo,
107-
RewritePatternSet &patterns,
108-
PatternBenefit benefit);
106+
RewritePatternSet &patterns);
109107

110108
} // namespace triton
111109
} // namespace mlir

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
342342

343343
// Multiply a square layout with 1 input and output dimension with a vector
344344
Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x);
345+
346+
// Whether the convert layout should be forced to use warp shuffles.
347+
bool cvtAlwaysUseWarpShuffle(triton::gpu::ConvertLayoutOp cvt);
345348
} // namespace gpu
346349

347350
} // namespace triton
@@ -442,6 +445,9 @@ Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
442445
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
443446
ArrayRef<unsigned> order);
444447

448+
GlobalOp getOrInsertGlobalConstant(RewriterBase &rewriter, ModuleOp module,
449+
Type type, Attribute content, StringRef key);
450+
445451
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
446452
StringRef content);
447453

@@ -630,6 +636,14 @@ SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
630636
mlir::TypeID::get<TerminatorOp>(), loc);
631637
}
632638

639+
// #prevBlock
640+
// if (condition) {
641+
// #ifBlock
642+
// }
643+
// #thenBlock
644+
std::tuple</*prevBlock=*/Block *, /*ifBlock=*/Block *, /*thenBlock=*/Block *>
645+
createIfBlock(ConversionPatternRewriter &b, Location loc, Value cnd);
646+
633647
void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
634648
ConversionPatternRewriter &rewriter,
635649
SmallVector<Value> &resultVals,

include/triton/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,39 @@
1010
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1111
#include "llvm/ADT/SetVector.h"
1212
#include <functional>
13+
#include <optional>
1314

1415
namespace mlir {
1516
namespace triton {
1617

1718
// Forward declaration
1819
class TritonLLVMIRRewriter;
1920

21+
//===----------------------------------------------------------------------===//
22+
// lowerWarpSpecializeBarriers
23+
//===----------------------------------------------------------------------===//
24+
25+
class WarpSpecializeBarrierHelper {
26+
public:
27+
virtual ~WarpSpecializeBarrierHelper() = default;
28+
29+
virtual bool isBarrierOp(Operation *op) const = 0;
30+
virtual Type getBarrierHandleType(MLIRContext *ctx) const = 0;
31+
virtual FailureOr<Value>
32+
getBarrierHandle(TritonLLVMIRRewriter &b,
33+
std::optional<unsigned> partitionIdx) = 0;
34+
virtual void createBarrier(TritonLLVMIRRewriter &b, unsigned numWarps,
35+
Value handle) = 0;
36+
LogicalResult createBarrier(TritonLLVMIRRewriter &b, unsigned numWarps,
37+
std::optional<unsigned> partitionIdx);
38+
};
39+
40+
// Assign hardware barriers to each warp group and rewrite warp group barriers
41+
// into named barrier instructions. There is a maximum number of named barriers.
42+
LogicalResult
43+
lowerWarpSpecializeBarriers(ModuleOp module,
44+
WarpSpecializeBarrierHelper &barrierHelper);
45+
2046
//===----------------------------------------------------------------------===//
2147
// convertOpTypes
2248
//===----------------------------------------------------------------------===//

include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ namespace mlir::triton {
2222
class FuncOp;
2323

2424
namespace instrument {
25+
std::string mangleType(Type t);
2526

2627
class ManglingArgs {
2728
public:
28-
using Arg = std::variant<Type, int, std::string>;
29+
using Arg = std::variant<Type, uint64_t, std::string>;
2930

3031
ManglingArgs() = default;
3132
ManglingArgs(const ManglingArgs &) = default;
@@ -51,9 +52,8 @@ class ManglingArgs {
5152

5253
std::string mangleArg(Arg arg) const {
5354
if (auto type = std::get_if<Type>(&arg)) {
54-
auto hash = static_cast<uint64_t>(mlir::hash_value(*type));
55-
return std::string("_T") + llvm::utohexstr(hash);
56-
} else if (auto intVal = std::get_if<int>(&arg)) {
55+
return std::string("_") + mangleType(*type);
56+
} else if (auto intVal = std::get_if<uint64_t>(&arg)) {
5757
return std::string("_I") + std::to_string(*intVal);
5858
} else if (auto stringVal = std::get_if<std::string>(&arg)) {
5959
return *stringVal;
@@ -74,18 +74,14 @@ class ManglingArgs {
7474
SmallVector<Arg> args;
7575
};
7676

77-
/// Utility to mangle helper function names produced by the instrumentation
78-
/// passes. The mangled name encodes the base name, number of warps and the
79-
/// participating types.
80-
std::string mangleInstrumentHelperName(const std::string &baseName,
81-
int numWarps,
82-
llvm::ArrayRef<Type> types);
83-
8477
class FunctionBuilder {
8578
public:
8679
FunctionBuilder(ModuleOp module, AuxDataMap &auxData)
8780
: module(module), auxData(auxData) {}
8881

82+
// Create a function that fills a global tensor with a scalar value.
83+
void createFillGlobalTensorCall(ImplicitLocOpBuilder &b, Value ptr,
84+
RankedTensorType type, Value scalar);
8985
// setWaiting: mark the base thread as waiting on the given barrier phase and
9086
// record that phase for deadlock detection.
9187
void createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread,

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,17 @@ class TTI_Op<string mnemonic, list<Trait> traits = []> :
2121
Op<TritonInstrument_Dialect, mnemonic, traits> {
2222
}
2323

24-
def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
25-
let summary = "assert the condition within the current thread";
24+
def TTI_ExperimentalAssertUniformOp : TTI_Op<"experimental_assert_uniform", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
25+
let summary = "assert the uniform condition";
2626
let description = [{
27-
Assert that the condition is true given all the values are available in the current thread.
28-
If the condition is false, the message is printed, and the program is aborted.
29-
If check_any is true, any of the values in the condition must be true. Otherwise, all the
30-
values in the condition must be true.
27+
Assert that the condition is true given all threads in the warp group have
28+
the same value, so only one thread needs to evaluate the assert and print
29+
the message.
3130
}];
32-
let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message, BoolAttr:$check_any);
33-
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
31+
let arguments = (ins I1:$condition, StrAttr:$message);
32+
let assemblyFormat = "$condition `,` $message attr-dict-with-keyword";
3433
}
3534

36-
3735
def TTI_ExperimentalBufferDescriptorsOp
3836
: TTI_Op<"experimental_buffer_descriptors", [Pure]> {
3937
let summary = "define an array of buffer descriptors";

include/triton/Dialect/TritonInstrument/IR/Utility.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <array>
1010

1111
namespace mlir::triton::instrument {
12+
class FunctionBuilder;
1213

1314
constexpr int numMemTypes = getMaxEnumValForMemType() + 1;
1415

@@ -22,18 +23,22 @@ namespace CommitKind {
2223
enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds };
2324
}
2425

26+
void createAssertInThread(ImplicitLocOpBuilder &b, Value condition,
27+
StringRef message);
2528
Operation *createStoreScratchMemory(OpBuilder &b, Location loc, Value alloc,
2629
Value tensor, RankedTensorType tensorType);
2730
Value createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc,
2831
RankedTensorType tensorType);
2932
Value expandOuterSlicedDim(OpBuilder &b, Location loc, Value tensor);
33+
RankedTensorType getIntTensorType(Region *region, ArrayRef<int64_t> shape,
34+
unsigned bitWidth);
3035
TypedValue<RankedTensorType> createConstIntTensor(OpBuilder &builder,
3136
Location loc, int64_t val,
3237
RankedTensorType tensorType,
3338
bool isSigned = false);
3439
FuncOp getEntryPoint(ModuleOp module);
3540
gpu::DistributedEncodingTrait
36-
getSingleDimSliceEncoding(gpu::BlockedEncodingAttr encoding, int dim);
41+
getSingleDimSliceEncoding(gpu::DistributedEncodingTrait encoding, int dim);
3742

3843
struct ValueType {
3944
Value value;
@@ -82,7 +87,8 @@ struct AuxDataMap {
8287
RegionToValueMap waiting;
8388
std::array<bool, numMemTypes> hasNonTrivialAliasing{};
8489

85-
void populateAndPassToWarpSpecialize(ModuleOp module);
90+
void populateAndPassToWarpSpecialize(ModuleOp module,
91+
FunctionBuilder &funcBuilder);
8692

8793
private:
8894
void getBuffersAndBarriers(

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
3131
rewriter, loc, elemTy,
3232
rewriter.getZeroAttr(elemTy))));
3333
} else {
34-
assert(false && "Unsupported type for assert");
35-
return failure();
34+
return op->emitError("Unsupported type for assert");
3635
}
3736
}
3837
llAssert(op, condition, adaptor.getMessage(), rewriter);
@@ -49,11 +48,11 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
4948
}
5049
// op: the op at which the assert is inserted. Unlike printf, we need to
5150
// know about the op to split the block.
52-
void llAssert(Operation *op, Value condition, StringRef message,
51+
void llAssert(AssertOp op, Value condition, StringRef message,
5352
ConversionPatternRewriter &rewriter) const {
5453

55-
auto ctx = rewriter.getContext();
5654
auto loc = op->getLoc();
55+
auto b = TritonLLVMOpBuilder(loc, rewriter);
5756

5857
StringRef file = "unknown";
5958
StringRef func = "unknown";
@@ -72,24 +71,13 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
7271
col = fileLineColLoc.getColumn();
7372
}
7473

75-
// #block1
76-
// if (condition) {
77-
// #block2
78-
// __assertfail(message);
79-
// }
80-
// #block3
81-
Block *prevBlock = op->getBlock();
74+
auto [prevBlock, ifBlock, thenBlock] =
75+
createIfBlock(rewriter, loc, condition);
8276

83-
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
8477
rewriter.setInsertionPointToStart(ifBlock);
8578
targetInfo.assertFail(rewriter, loc, message, file, func, line);
8679

8780
// Split a block after the call.
88-
Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator());
89-
rewriter.setInsertionPointToEnd(ifBlock);
90-
LLVM::BrOp::create(rewriter, loc, thenBlock);
91-
rewriter.setInsertionPointToEnd(prevBlock);
92-
LLVM::CondBrOp::create(rewriter, loc, condition, ifBlock, thenBlock);
9381
rewriter.setInsertionPointToStart(thenBlock);
9482
}
9583

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_triton_library(TritonGPUToLLVM
55
AllocateSharedMemoryUtility.cpp
66
AllocateWarpGroups.cpp
77
AssertOpToLLVM.cpp
8+
CanonicalizeLLVMIR.cpp
89
ControlFlowOpToLLVM.cpp
910
ConvertLayoutOpToLLVM.cpp
1011
ElementwiseOpToLLVM.cpp

0 commit comments

Comments
 (0)