Skip to content

Commit 49d447c

Browse files
committed
Merge denzel and ethan work, and squash
1 parent c6b8da1 commit 49d447c

199 files changed

Lines changed: 204105 additions & 71384 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
.working/
3+
output/
14
# Triton builds
25
build/
36
build-*/

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
|-------------------- | -------------------- |
77
| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |
88

9+
# COMS E6998 Runtime Analysis Project
10+
11+
See [our readme](runtime-analysis/README.md) in the `runtime-analysis` folder for details.
12+
913
# Triton
1014

1115
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.

bin/benchmark-lower.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#include "./RegisterTritonDialects.h"
2+
3+
#include "mlir/Tools/mlir-reduce/MlirReduceMain.h"
4+
5+
int main(int argc, char **argv) {
6+
mlir::DialectRegistry registry;
7+
registerTritonDialects(registry);
8+
9+
mlir::MLIRContext context(registry);
10+
return mlir::failed(mlir::mlirReduceMain(argc, argv, context));
11+
}

commands.md

Lines changed: 0 additions & 49 deletions
This file was deleted.

empty_kernel.cpp

Lines changed: 492 additions & 0 deletions
Large diffs are not rendered by default.

include/triton/Target/LLVMIR/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ namespace mlir {
99
std::unique_ptr<Pass> createLLVMDIScopePass();
1010

1111
std::unique_ptr<Pass> createLLVMRuntimeAnalysisPass();
12+
std::unique_ptr<Pass> createLLVMRuntimeAnalysisPass(const std::string &analysisResultFile);
1213

1314
/// Generate the code for registering conversion passes.
1415
#define GEN_PASS_REGISTRATION

include/triton/Target/LLVMIR/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ def LLVMRuntimeAnalysis: Pass<"runtime-analysis", "mlir::ModuleOp"> {
2020
}];
2121

2222
let constructor = "mlir::createLLVMRuntimeAnalysisPass()";
23+
24+
let options = [
25+
Option<"result_file", "result_file", "std::string", /*default*/"\"\"",
26+
"path to place analysis results">,
27+
];
2328
}
2429

2530
#endif

lib/Analysis/Runtime.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Denzel / Ethan Project, Analyze per-kernel Runtime
2+
3+
4+
5+
// Components:
6+
// - kernel launch overhead, depends on parameters for kernel launch (number of blocks, number of threads per block)
7+
// - Do some profiling, build a linear regression
8+
// - compute-bound instruction latency
9+
// - Do some profiling, build a linear regression

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_triton_library(TritonGPUToLLVM
2121
SPMDOpToLLVM.cpp
2222
DecomposeUnsupportedConversions.cpp
2323
PrintOpToLLVM.cpp
24+
LoopHintPass.cpp
2425

2526
DEPENDS
2627
TritonGPUConversionPassIncGen
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
2+
#include "mlir/Dialect/Arith/IR/Arith.h"
3+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4+
#include "mlir/Dialect/Func/IR/FuncOps.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
7+
8+
using namespace mlir;
9+
10+
/// Matches scf -> cf loop pattern:
11+
/// %cmp = arith.cmpi slt, %lhs, %const : i32
12+
/// cf.cond_br %cmp, ^loop, ^exit
13+
///
14+
/// and annotates the `cf.cond_br` with loop iteration count.
15+
16+
namespace {
17+
18+
/// Rewrite pattern that tags a `cf.cond_br` with a `loop_count` attribute.
///
/// The pattern fires only when the branch condition is produced by an
/// `arith.cmpi slt` whose right-hand side is an `arith.constant` carrying an
/// integer attribute. The recorded value is that constant, i.e. the bound the
/// comparison tests against, stored as an i64 integer attribute.
///
/// NOTE(review): the constant equals the real trip count only when the
/// induction variable starts at 0 and advances by 1 -- confirm this holds for
/// the loops this pass is applied to.
struct AnnotateLoopCondBrPattern : public OpRewritePattern<cf::CondBranchOp> {
  using OpRewritePattern<cf::CondBranchOp>::OpRewritePattern;

  /// Returns success() only when the attribute was actually added or changed;
  /// returns failure() when the op does not match or is already annotated
  /// with the same value.
  LogicalResult matchAndRewrite(cf::CondBranchOp condBrOp,
                                PatternRewriter &rewriter) const override {
    auto cmpOp = condBrOp.getCondition().getDefiningOp<arith::CmpIOp>();
    if (!cmpOp || cmpOp.getPredicate() != arith::CmpIPredicate::slt)
      return failure();

    auto rhsOp = cmpOp.getRhs().getDefiningOp<arith::ConstantOp>();
    if (!rhsOp)
      return failure();

    auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(rhsOp.getValue());
    if (!intAttr)
      return failure();

    IntegerAttr loopCountAttr = rewriter.getI64IntegerAttr(intAttr.getInt());

    // Reporting success without changing anything makes the greedy driver
    // believe the IR is still evolving; it then exhausts its iteration budget
    // and reports failure. Bail out once the annotation is already in place.
    if (condBrOp->getAttr("loop_count") == loopCountAttr)
      return failure();

    // Route the mutation through the rewriter so the driver is notified of
    // the in-place change.
    rewriter.modifyOpInPlace(
        condBrOp, [&] { condBrOp->setAttr("loop_count", loopCountAttr); });

    return success();
  }
};
47+
48+
/// A pass that runs on a function and looks for any cf.cond_br that
49+
/// implements a loop condition (in the pattern above), then annotates
50+
/// it with `loop_count`.
51+
struct AnnotateLoopCountPass
52+
: public PassWrapper<AnnotateLoopCountPass, OperationPass<func::FuncOp>> {
53+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateLoopCountPass)
54+
55+
StringRef getArgument() const final { return "annotate-loop-count"; }
56+
StringRef getDescription() const final {
57+
return "Annotate cf.cond_br ops with loop iteration counts.";
58+
}
59+
60+
void runOnOperation() override {
61+
auto function = getOperation();
62+
63+
RewritePatternSet patterns(&getContext());
64+
patterns.add<AnnotateLoopCondBrPattern>(&getContext());
65+
66+
if (failed(applyPatternsAndFoldGreedily(function.getBody(),
67+
std::move(patterns)))) {
68+
signalPassFailure();
69+
}
70+
}
71+
};
72+
73+
} // namespace
74+
75+
std::unique_ptr<Pass> createAnnotateLoopCountPass() {
76+
return std::make_unique<AnnotateLoopCountPass>();
77+
}

0 commit comments

Comments
 (0)