
Commit 4027076

Merge remote-tracking branch 'origin/main' into phil/cleanup-scale
2 parents 3697265 + df38505 commit 4027076

110 files changed

Lines changed: 6729 additions & 1572 deletions


.github/workflows/integration-tests-nvidia.yml

Lines changed: 7 additions & 6 deletions
@@ -9,15 +9,16 @@ on:
 
 jobs:
   integration-tests-nvidia:
-    runs-on: ${{ matrix.runner }}
+    name: integration-tests-nvidia (${{ matrix.config.name }})
+    runs-on: ${{ matrix.config.runs_on }}
     timeout-minutes: 60
     # Let A100 and H100 continue even if GB200 fails, as it's a bit flaky
-    continue-on-error: ${{ matrix.runner[0] == 'nvidia-gb200'}}
+    continue-on-error: ${{ startsWith(matrix.config.runner_type, 'nvidia-gb200') }}
     strategy:
       matrix:
-        runner: ${{ fromJson(inputs.matrix) }}
+        config: ${{ fromJson(inputs.matrix) }}
     env:
-      RUNNER_TYPE: ${{ matrix.runner[0] }}
+      RUNNER_TYPE: ${{ matrix.config.runner_type }}
       TRITON_BUILD_WITH_CCACHE: "true"
       TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
       TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
@@ -69,7 +70,7 @@ jobs:
         run: |
           echo "$HOME/.local/bin" >> $GITHUB_PATH
       - name: Setup Python environment for GB200
-        if: ${{ matrix.runner[0] == 'nvidia-gb200' }}
+        if: ${{ startsWith(matrix.config.runner_type, 'nvidia-gb200') }}
         run: |
           echo "/venv/bin" >> $GITHUB_PATH
           echo "VIRTUAL_ENV=/venv" >> $GITHUB_ENV
@@ -90,7 +91,7 @@ jobs:
       - name: Run python tests on CUDA
         run: make NUM_PROCS=24 test-unit
       - name: Run interpreter tests
-        if: ${{ matrix.runner[0] == 'nvidia-h100' }}
+        if: ${{ matrix.config.runner_type == 'nvidia-h100' }}
         run: make test-interpret
       - name: Run regression tests
         run: make test-regression

.github/workflows/runner-preparation.yml

Lines changed: 2 additions & 2 deletions
@@ -95,11 +95,11 @@ jobs:
         if: env.enable_integration == 'true'
         run: |
           if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
-            echo '::set-output name=matrix-NVIDIA::[["nvidia-a100"], ["nvidia-h100"], ["nvidia-gb200"]]'
+            echo '::set-output name=matrix-NVIDIA::[{"name":"nvidia-a100","runner_type":"nvidia-a100","runs_on":["nvidia-a100"]},{"name":"nvidia-h100","runner_type":"nvidia-h100","runs_on":["nvidia-h100"]},{"name":"nvidia-gb200","runner_type":"nvidia-gb200","runs_on":{"group":"gb200-runner-set"}}]'
             echo '::set-output name=matrix-AMD::[["self-hosted", "gfx90a"], ["amd-gfx942"], ["amd-gfx950"]]'
             echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
           else
-            echo '::set-output name=matrix-NVIDIA::["ubuntu-latest"]'
+            echo '::set-output name=matrix-NVIDIA::[{"name":"ubuntu-latest","runner_type":"ubuntu-latest","runs_on":"ubuntu-latest"}]'
             echo '::set-output name=matrix-AMD::["ubuntu-latest"]'
             echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
           fi

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::gpu::registerAllocateSharedMemoryPass();
   mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
   mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass();
+  mlir::triton::gpu::registerCanonicalizeLLVMIR();
   mlir::triton::registerConvertWarpSpecializeToLLVM();
   mlir::triton::registerConvertTritonGPUToLLVMPass();
   mlir::triton::registerConvertNVGPUToLLVMPass();

include/triton/Analysis/Allocation.h

Lines changed: 51 additions & 21 deletions
@@ -20,6 +20,10 @@ using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
 
+unsigned getNumScratchElemsSwizzledCvt(const LinearLayout &srcLayout,
+                                       const LinearLayout &dstLayout,
+                                       int bitwidth);
+
 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                        RankedTensorType dstTy);
 
@@ -70,8 +74,11 @@ class Allocation {
   explicit Allocation(Operation *operation) : operation(operation) {}
 
   /// Runs allocation analysis on the given top-level operation.
+  /// \param sharedMemoryPartitionSize The size of each shared memory partition
+  /// in bytes. A value of 0 means shared memory is not partitioned.
   void run(FuncAllocMapT &funcAllocMap,
-           triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);
+           triton::AllocationAnalysisScratchSizeFn scratchSizeGetter,
+           size_t sharedMemoryPartitionSize = 0);
 
   /// Returns the operation this analysis was constructed from.
   Operation *getOperation() const { return operation; }
@@ -92,24 +99,29 @@ class Allocation {
     return Interval<size_t>(buffer.offset, buffer.offset + buffer.size);
   }
 
-  /// Returns the buffer id of the given value.
-  /// This interface only returns the allocated buffer id.
-  /// If you want to get all the buffer ids that are associated with the given
-  /// value, including alias buffers, use getBufferIds.
-  BufferId getBufferId(Value value) const {
-    if (valueBuffer.count(value)) {
-      return valueBuffer.lookup(value)->id;
-    } else {
-      return InvalidBufferId;
+  /// Returns all buffer ids for a value.
+  /// For partitioned tensors, returns all logical piece buffer ids.
+  /// For non-partitioned values, returns a single-element vector.
+  /// Returns empty vector if value has no associated buffer.
+  SmallVector<BufferId> getBufferIds(Value value) const {
+    SmallVector<BufferId> bufferIds;
+    auto it = valueBuffer.find(value);
+    if (it == valueBuffer.end())
+      return bufferIds;
+
+    for (auto *buffer : it->second) {
+      bufferIds.push_back(buffer->id);
     }
+    return bufferIds;
   }
 
-  /// Returns all the buffer ids of the given value, including alias buffers.
-  BufferIdSetT getBufferIds(Value value) const {
+  /// Returns all buffer ids of the given value, including alias buffers.
+  /// This is a superset of getBufferIds that also includes aliased buffers.
+  BufferIdSetT getAllBufferIdsWithAliases(Value value) const {
     BufferIdSetT bufferIds;
-    auto allocBufferId = getBufferId(value);
-    if (allocBufferId != InvalidBufferId)
-      bufferIds.insert(allocBufferId);
+    for (auto bufferId : getBufferIds(value)) {
+      bufferIds.insert(bufferId);
+    }
     for (auto *buffer : aliasBuffer.lookup(value)) {
       if (buffer->id != InvalidBufferId)
         bufferIds.insert(buffer->id);
@@ -154,6 +166,10 @@ class Allocation {
     size_t alignment;
     size_t offset;
 
+    /// For partitioned tensors: buffers that reside in different physical
+    /// partitions.
+    SmallVector<BufferT *> neighbors;
+
     bool operator==(const BufferT &other) const { return id == other.id; }
     bool operator<(const BufferT &other) const { return id < other.id; }
 
@@ -169,8 +185,8 @@ class Allocation {
 
   /// Op -> Scratch Buffer
   using OpScratchMapT = llvm::MapVector<Operation *, BufferT *>;
-  /// Value -> Explicit Buffer
-  using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
+  /// Value -> Explicit Buffers (vector for partitioned tensors)
+  using ValueBufferMapT = llvm::MapVector<Value, SmallVector<BufferT *>>;
   /// Value -> Alias Buffer
   using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
   /// BufferId -> Buffer
@@ -184,16 +200,28 @@ class Allocation {
         nextId, BufferT(Kind, nextId, key, std::forward<Args>(args)...));
     BufferT *buffer = &it->second;
     if constexpr (Kind == BufferT::BufferKind::Explicit) {
-      valueBuffer[key] = buffer;
+      valueBuffer[key].push_back(buffer);
     } else if constexpr (Kind == BufferT::BufferKind::Virtual) {
       opVirtual[key] = buffer;
     } else {
       opScratch[key] = buffer;
    }
   }
 
+  /// Create multiple buffers for partitions where all different partitions
+  /// are neighbors (must be placed in different physical shared memory slots).
+  ///
+  /// \param key The value that owns these buffers
+  /// \param numPartitions Number of partition buffers to create
+  /// \param partitionSize Size of each partition buffer in bytes
+  /// \param alignment Required alignment for each buffer
+  void addPartitionBuffers(Value key, unsigned numPartitions,
+                           size_t partitionSize, size_t alignment);
+
   void addAlias(Value value, Value alloc) {
-    aliasBuffer[value].insert(valueBuffer[alloc]);
+    for (auto *buffer : valueBuffer[alloc]) {
+      aliasBuffer[value].insert(buffer);
+    }
   }
 
 private:
@@ -222,7 +250,8 @@ class ModuleAllocation : public triton::CallGraph<Allocation> {
 
   ModuleAllocation(ModuleOp moduleOp,
                    triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
-                       triton::defaultAllocationAnalysisScratchSizeFn)
+                       triton::defaultAllocationAnalysisScratchSizeFn,
+                   size_t sharedMemoryPartitionSize = 0)
       : triton::CallGraph<Allocation>(moduleOp) {
     walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
         // Pre-order edge walk callback
@@ -231,7 +260,8 @@ class ModuleAllocation : public triton::CallGraph<Allocation> {
         [&](FunctionOpInterface funcOp) {
          auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
           if (inserted)
-            iter->second.run(funcMap, scratchSizeGetter);
+            iter->second.run(funcMap, scratchSizeGetter,
+                             sharedMemoryPartitionSize);
         });
   }
 
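
Since getBufferId(Value) is gone and getBufferIds(Value) now returns one id per logical piece (with the alias-aware query renamed to getAllBufferIdsWithAliases), callers iterate over a vector instead of handling a single id. A minimal sketch of the updated call pattern, assuming an Allocation and a Value obtained elsewhere in the analysis; the helper itself is hypothetical and not part of this commit:

#include "triton/Analysis/Allocation.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical helper: gather every buffer id reachable from `v`.
// getBufferIds yields one id per partitioned piece (a single entry for
// non-partitioned values, nothing if `v` owns no shared-memory buffer);
// getAllBufferIdsWithAliases additionally covers alias buffers.
static SmallVector<Allocation::BufferId>
collectBufferIds(const Allocation &alloc, Value v) {
  SmallVector<Allocation::BufferId> ids;
  for (Allocation::BufferId id : alloc.getBufferIds(v))
    ids.push_back(id);
  for (Allocation::BufferId id : alloc.getAllBufferIdsWithAliases(v))
    if (!llvm::is_contained(ids, id))
      ids.push_back(id);
  return ids;
}

The new neighbors field and addPartitionBuffers work together with this: each partitioned piece gets its own BufferT, and pieces of the same value are recorded as neighbors so that offset assignment places them in distinct physical partitions.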

include/triton/Analysis/Utility.h

Lines changed: 41 additions & 13 deletions
@@ -3,6 +3,7 @@
 
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -24,6 +25,22 @@ inline bool isZeroConst(Value v) {
 
 class ReduceOpHelper {
 public:
+  enum class InThreadVectorizeOpKind {
+    None,
+    AddF,
+    MulF,
+    MinNumF,
+    MaxNumF,
+    MinimumF,
+    MaximumF,
+    AddI,
+    MulI,
+    MinSI,
+    MaxSI,
+    MinUI,
+    MaxUI,
+  };
+
   explicit ReduceOpHelper(triton::ReduceOp op)
       : op(op.getOperation()), axis(op.getAxis()) {
     auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
@@ -42,30 +59,41 @@ class ReduceOpHelper {
     }
   }
 
-  ArrayRef<int64_t> getSrcShape() { return srcShape; }
+  RankedTensorType getSrcTy() { return srcTy; }
 
-  Attribute getSrcLayout() { return srcEncoding; }
+  unsigned getInterWarpSizeWithUniqueData();
 
-  triton::ReduceOp getOperation() { return op; }
+  unsigned getIntraWarpSizeWithUniqueData();
 
-  unsigned getThreadOffsetOnReductionAxis();
+  bool isReduceWithinCTA();
 
-  bool isWarpSynchronous();
+  bool isAssociative();
 
-  unsigned getInterWarpSizeWithUniqueData();
+  unsigned getScratchSizeInBytes();
 
-  unsigned getIntraWarpSizeWithUniqueData();
+  InThreadVectorizeOpKind
+  getInThreadVectorizeOpKind(unsigned axisPack,
+                             bool supportBitwidth16Elementwise,
+                             bool supportBitwidth32Elementwise);
 
-  // The shape of the shared memory space needed for the reduction.
-  SmallVector<unsigned> getScratchRepShape();
+  static triton::ColumnAction
+  moveAxisBasesToFront(const triton::LinearLayout &layout, int axis,
+                       bool isVectorized = false);
 
-  SmallVector<unsigned> getOrderWithAxisAtBeginning();
+  static triton::LinearLayout
+  zeroBasesAlongDimAndReorder(const triton::LinearLayout &layout, unsigned axis,
+                              mlir::StringAttr dim);
 
-  unsigned getScratchSizeInBytes();
+  static triton::LinearLayout getInterLayout(const triton::LinearLayout &layout,
+                                             unsigned axis);
 
-  bool isReduceWithinCTA();
+  static triton::LinearLayout reducedRegLaneLayout(RankedTensorType srcTy,
+                                                   unsigned axis);
 
-  bool isAssociative();
+  static Value createInThreadVectorizedCombineOp(OpBuilder &builder,
+                                                 Location loc,
+                                                 InThreadVectorizeOpKind kind,
+                                                 Value lhs, Value rhs);
 
 private:
   triton::ReduceOp op;
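
The InThreadVectorizeOpKind enum classifies a reduce's combine region so a lowering can emit one packed combine instead of per-element scalar ops. A hedged sketch of how the two new entry points might be consumed together, assuming ReduceOpHelper is visible from namespace mlir as in this header, and that the helper, builder, packed operands, and TargetInfoBase are supplied by the caller; the fallback logic is illustrative, not this commit's actual lowering:

#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"

using namespace mlir;
using Kind = ReduceOpHelper::InThreadVectorizeOpKind;

// Illustrative only: classify the combine op, then either emit a single
// vectorized combine or signal that the scalar path must be used.
static Value tryPackedCombine(ReduceOpHelper &helper, OpBuilder &builder,
                              Location loc, Value lhs, Value rhs,
                              unsigned axisPack,
                              const triton::TargetInfoBase &target) {
  Kind kind = helper.getInThreadVectorizeOpKind(
      axisPack, target.supportBitwidth16Elementwise(),
      target.supportBitwidth32Elementwise());
  if (kind == Kind::None)
    return Value(); // not a recognized associative op; fall back to scalars
  return ReduceOpHelper::createInThreadVectorizedCombineOp(builder, loc, kind,
                                                           lhs, rhs);
}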

include/triton/Conversion/TritonGPUToLLVM/Passes.td

Lines changed: 4 additions & 0 deletions
@@ -42,4 +42,8 @@ def TritonGPUAllocateWarpGroups : Pass<"tritongpu-allocate-warp-groups", "mlir::
   }];
 }
 
+def CanonicalizeLLVMIR : Pass<"canonicalize-llvm-ir", "mlir::LLVM::LLVMFuncOp"> {
+  let summary = "Canonicalize LLVM IR";
+}
+
 #endif
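
The pass is anchored on mlir::LLVM::LLVMFuncOp, so it would normally be added as a nested pass late in the GPU-to-LLVM pipeline. A sketch under the assumption that TableGen emits the usual createCanonicalizeLLVMIR() factory next to the registerCanonicalizeLLVMIR() hook seen in bin/RegisterTritonDialects.h above; neither the factory name nor the declaring header path is shown in this commit:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h" // assumed declaring header

// Hypothetical pipeline wiring; only the pass name and its anchor op come
// from this commit, the factory function name is assumed.
void addCanonicalizeLLVMIR(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
      mlir::triton::gpu::createCanonicalizeLLVMIR());
}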

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 3 deletions
@@ -103,9 +103,7 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit);
 
 void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter,
-                                           const TargetInfoBase &targetInfo,
-                                           RewritePatternSet &patterns,
-                                           PatternBenefit benefit);
+                                           RewritePatternSet &patterns);
 
 } // namespace triton
 } // namespace mlir

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 8 additions & 2 deletions
@@ -2,6 +2,7 @@
 #define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H
 
 #include "triton/Conversion/MLIRTypes.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir::triton {
 enum class ProgramIDDim : uint32_t;
@@ -66,8 +67,7 @@ class TargetInfoBase {
 
   virtual bool warpReduce(RewriterBase &rewriter, Location loc,
                           SmallVector<Value> &acc, triton::ReduceOp op,
-                          unsigned numLaneToReduce,
-                          unsigned interleave) const = 0;
+                          unsigned reduceLaneIdMask) const = 0;
 
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
@@ -102,8 +102,14 @@ class TargetInfoBase {
   virtual bool supportLdMatrix() const { return false; }
   virtual bool supportStMatrix() const { return false; }
   virtual bool supportLdStMatrixB8() const { return false; }
+  virtual bool supportBitwidth16Elementwise() const { return false; }
+  virtual bool supportBitwidth32Elementwise() const { return false; }
   virtual bool isCuda() const { return false; }
 
+  // Returns the shared memory partition size in bytes. A value of 0 means
+  // shared memory is not partitioned.
+  virtual size_t getSharedMemoryPartitionSize() const { return 0; }
+
   // Annotate target specific information to local load operations during
   // lowering to LLVM. `llLoadOp` is the generated LLVM load op.
   virtual void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp,
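
All three new hooks are virtual with conservative defaults, so existing targets keep compiling unchanged and a backend opts in by overriding them. A hypothetical subclass sketch; the class name and the 32 KB figure are illustrative, not from this commit:

#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"

// Hypothetical target that opts into the new capabilities.
class MyTargetInfo : public mlir::triton::TargetInfoBase {
public:
  // Allow in-thread vectorized reduce combines on 16- and 32-bit elements.
  bool supportBitwidth16Elementwise() const override { return true; }
  bool supportBitwidth32Elementwise() const override { return true; }
  // Shared memory is split into 32 KB partitions; 0 (the base-class default)
  // means "not partitioned". ModuleAllocation's new sharedMemoryPartitionSize
  // parameter is the intended consumer of this value.
  size_t getSharedMemoryPartitionSize() const override { return 32 * 1024; }
  // The remaining pure-virtual methods (warpReduce with its new
  // reduceLaneIdMask parameter, getMulhiFuncName, ...) still need overrides
  // and are omitted here.
};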

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 14 additions & 0 deletions
@@ -342,6 +342,9 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
 
 // Multiply a square layout with 1 input and output dimension with a vector
 Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x);
+
+// Whether the convert layout should be forced to use warp shuffles.
+bool cvtAlwaysUseWarpShuffle(triton::gpu::ConvertLayoutOp cvt);
 } // namespace gpu
 
 } // namespace triton
@@ -442,6 +445,9 @@ Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
 size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
                  ArrayRef<unsigned> order);
 
+GlobalOp getOrInsertGlobalConstant(RewriterBase &rewriter, ModuleOp module,
+                                   Type type, Attribute content, StringRef key);
+
 Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
                         StringRef content);
 
@@ -630,6 +636,14 @@ SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
       mlir::TypeID::get<TerminatorOp>(), loc);
 }
 
+// #prevBlock
+// if (condition) {
+//   #ifBlock
+// }
+// #thenBlock
+std::tuple</*prevBlock=*/Block *, /*ifBlock=*/Block *, /*thenBlock=*/Block *>
+createIfBlock(ConversionPatternRewriter &b, Location loc, Value cnd);
+
 void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
                                  ConversionPatternRewriter &rewriter,
                                  SmallVector<Value> &resultVals,
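
The block diagram in the comment above createIfBlock corresponds to a usage pattern like the one below: ops emitted into ifBlock run only when the condition holds, and lowering resumes in thenBlock. A hedged sketch; the wrapper function is hypothetical, and the namespace qualification of createIfBlock (elided here) follows this header:

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

using namespace mlir;

// Hypothetical helper for a conversion pattern: build `buildBody` under
// `pred`, then leave the insertion point in the fall-through block.
static void emitGuarded(ConversionPatternRewriter &rewriter, Location loc,
                        Value pred, llvm::function_ref<void()> buildBody) {
  auto [prevBlock, ifBlock, thenBlock] = createIfBlock(rewriter, loc, pred);
  (void)prevBlock; // block that now ends with the conditional branch
  rewriter.setInsertionPointToStart(ifBlock);
  buildBody(); // ops created here execute only when `pred` is true
  rewriter.setInsertionPointToStart(thenBlock);
  // lowering continues unconditionally from here
}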
