Skip to content

Commit ad5ae7a

Browse files
authored
Merge branch 'main' into sjw/tensdesc-names
2 parents ee6c54a + fb5c197 commit ad5ae7a

90 files changed

Lines changed: 1609 additions & 619 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
~/.triton/json
8686
key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-json-${{ steps.cache-key.outputs.json }}
8787
- name: Install dependencies
88-
run: apt-get install -y clang lld ccache
88+
run: apt-get update && apt-get install -y clang lld ccache
8989
- name: Inspect cache directories
9090
run: |
9191
mkdir -p ~/.triton

.github/workflows/wheels.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ permissions: read-all
1212
jobs:
1313

1414
Build-Wheels:
15-
timeout-minutes: 120
15+
timeout-minutes: 180
1616
runs-on: ${{ matrix.config.runs_on }}
1717

1818
strategy:
@@ -99,12 +99,12 @@ jobs:
9999
path: ./wheelhouse/*.whl
100100

101101
- name: Install Azure CLI
102-
if: ${{ steps.check-version.outputs.new_commit == 'true' }}
102+
if: ${{ steps.check-version.outputs.new_commit == 'true' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') }}
103103
run: |
104104
curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
105105
106106
- name: Azure login
107-
if: ${{ steps.check-version.outputs.new_commit == 'true' }}
107+
if: ${{ steps.check-version.outputs.new_commit == 'true' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') }}
108108
uses: azure/login@v2
109109
with:
110110
client-id: ${{ secrets.AZURE_CLIENT_ID }}
@@ -113,20 +113,20 @@ jobs:
113113

114114
- id: generate-token
115115
name: Generate token
116-
if: ${{ steps.check-version.outputs.new_commit == 'true' }}
116+
if: ${{ steps.check-version.outputs.new_commit == 'true' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') }}
117117
run: |
118118
AZ_TOKEN=$(az account get-access-token --query accessToken)
119119
echo "::add-mask::$AZ_TOKEN"
120120
echo "access_token=$AZ_TOKEN" >> "$GITHUB_OUTPUT"
121121
122122
- name: Publish wheels to Azure DevOps
123-
if: ${{ steps.check-version.outputs.new_commit == 'true' }}
123+
if: ${{ steps.check-version.outputs.new_commit == 'true' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') }}
124124
run: |
125125
python3 -m pip install twine
126126
python3 -m twine upload -r Triton-Nightly -u TritonArtifactsSP -p ${{ steps.generate-token.outputs.access_token }} --config-file utils/nightly.pypirc --non-interactive --verbose wheelhouse/*
127127
128128
- name: Azure Logout
129-
if: ${{ steps.check-version.outputs.new_commit == 'true' && (success() || failure()) }}
129+
if: ${{ steps.check-version.outputs.new_commit == 'true' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && (success() || failure()) }}
130130
run: |
131131
az logout
132132
az cache purge

include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,23 @@ class FunctionBuilder {
9393
// matching barrier phases.
9494
void createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, int activeMask,
9595
Value pred, Operation *insertPoint);
96+
// verifyBarrierCanInit: ensure the barrier is currently invalidated before
97+
// initializing it again.
98+
void createVerifyBarrierCanInitCall(ImplicitLocOpBuilder &b, Value mbar,
99+
Operation *insertPoint);
100+
// verifyBarrierInitialized: ensure the barrier has been initialized and not
101+
// invalidated before it is used.
102+
void createVerifyBarrierInitializedCall(ImplicitLocOpBuilder &b, Value mbar,
103+
Value pred, Operation *insertPoint);
96104
// initBarrierState: Initialize the tracked barrier state to phase 0 and set
97-
// both the initial and current arrival counts.
105+
// both the initial and current arrival counts. A zero state denotes an
106+
// invalidated/uninitialized barrier.
98107
void createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
99108
int count, Operation *insertPoint);
109+
// invalidateBarrierState: clear the tracked barrier lifecycle state and any
110+
// waiting bits for the barrier.
111+
void createInvalidateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
112+
Operation *insertPoint);
100113
// verifyBarrierArrive: Check that applying the arrive count would not drive
101114
// the tracked current count negative. Triggers an assertion on failure.
102115
void createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value mbar,
@@ -145,6 +158,16 @@ class FunctionBuilder {
145158
void createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
146159
int thread, Value pred, MemType memType,
147160
Operation *insertPoint);
161+
// clearBarrierWriteTracking: clear all write tracking associated with the
162+
// given barrier row.
163+
void createClearBarrierWriteTrackingCall(ImplicitLocOpBuilder &b, Value mbar,
164+
Value pred, MemType memType,
165+
Operation *insertPoint);
166+
// clearBarrierReadTracking: clear all read tracking associated with the
167+
// given barrier row.
168+
void createClearBarrierReadTrackingCall(ImplicitLocOpBuilder &b, Value mbar,
169+
Value pred, MemType memType,
170+
Operation *insertPoint);
148171
// transferVisibleWrites: transfer write visibility tracked by a barrier to
149172
// all threads in threadMask.
150173
void createTransferVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, bool unpacked,
117117

118118
TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);
119119

120+
uint32_t getTMemSubSliceOffset(gpu::MemDescType memDescType, int32_t nOffset);
121+
120122
SmallVector<gpu::DistributedEncodingTrait>
121123
getTmemCompatibleLayouts(gpu::MemDescType memType, unsigned numWarps,
122124
ArrayRef<int64_t> ctaSplit = {1, 1});

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,8 @@ def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods<MemoryEf
896896
}];
897897
}
898898

899-
def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> {
899+
def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure,
900+
MemDescViewTrait]> {
900901
let summary = "Take a subslice of a tensor memory allocation";
901902
let description = [{
902903
This operation takes a subslice of a tensor memory allocation and returns a new descriptor

lib/Analysis/BufferRegion.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ llvm::DenseSet<Value> getBarrierOperands(Operation *op) {
6363
if (auto initBarrierOp = dyn_cast<ttng::InitBarrierOp>(op)) {
6464
return {initBarrierOp.getOperand()};
6565
}
66+
if (auto waitBarrierOp = dyn_cast<ttng::WaitBarrierOp>(op)) {
67+
return {waitBarrierOp.getAlloc()};
68+
}
69+
if (auto arriveBarrierOp = dyn_cast<ttng::ArriveBarrierOp>(op)) {
70+
return {arriveBarrierOp.getAlloc()};
71+
}
6672
if (auto barrierExpectOp = dyn_cast<ttng::BarrierExpectOp>(op)) {
6773
return {barrierExpectOp.getAlloc()};
6874
}
@@ -269,7 +275,8 @@ LogicalResult BufferRegionAnalysis::visitOperation(
269275
if (auto tmemSubsliceOp = dyn_cast<ttng::TMEMSubSliceOp>(op)) {
270276
RegionInfo in = operands[0]->getValue();
271277
uint32_t subBufferSize = getMemDescSize(tmemSubsliceOp.getType());
272-
uint32_t relativeOffset = tmemSubsliceOp.getN();
278+
uint32_t relativeOffset = ttng::getTMemSubSliceOffset(
279+
tmemSubsliceOp.getType(), tmemSubsliceOp.getN());
273280
for (auto &region : in.regions) {
274281
regionInfo.regions.insert(
275282
{region.baseOffset + relativeOffset, subBufferSize});
@@ -326,8 +333,8 @@ bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) {
326333
ttng::TMEMStoreOp, ttg::AsyncCopyGlobalToLocalOp,
327334
ttng::AsyncTMACopyGlobalToLocalOp, ttng::AsyncTMACopyLocalToGlobalOp,
328335
ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp, ttng::InitBarrierOp,
329-
ttng::BarrierExpectOp, ttng::InvalBarrierOp, ttng::WaitBarrierOp>(
330-
op)) {
336+
ttng::BarrierExpectOp, ttng::InvalBarrierOp, ttng::WaitBarrierOp,
337+
ttng::ArriveBarrierOp>(op)) {
331338
return true;
332339
}
333340
// Allocations with operands write to the memory.

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,13 +549,9 @@ struct MemDescSubsliceOpConversion
549549
matchAndRewrite(triton::gpu::MemDescSubsliceOp op, OpAdaptor adaptor,
550550
ConversionPatternRewriter &rewriter) const override {
551551
Location loc = op->getLoc();
552-
auto *ctx = op->getContext();
553552
auto b = TritonLLVMOpBuilder(loc, rewriter);
554553
auto srcTy = op.getSrc().getType();
555-
auto destTy = op.getResult().getType();
556554
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
557-
auto layoutOrder = getOrder(srcTy);
558-
auto enc = srcTy.getEncoding();
559555

560556
// PartitionedSharedEncoding is not yet supported for memdesc_subslice
561557
if (isa<PartitionedSharedEncodingAttr>(srcTy.getEncoding())) {

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,9 @@ LogicalResult MemDescIndexOp::verify() {
881881
if (srcTy.getElementType() != dstTy.getElementType()) {
882882
return emitError("result element type must match desc element type");
883883
}
884+
if (srcTy.getEncoding() != dstTy.getEncoding()) {
885+
return emitError("src and result must have the same encoding");
886+
}
884887
// memdesc_index reduces rank by 1 and preserves the trailing shape.
885888
bool correctRank = srcTy.getRank() == dstTy.getRank() + 1;
886889
if (!correctRank) {
@@ -955,6 +958,9 @@ LogicalResult MemDescSubsliceOp::verify() {
955958
if (srcTy.getElementType() != dstTy.getElementType()) {
956959
return emitError("result element type must match desc element type");
957960
}
961+
if (srcTy.getEncoding() != dstTy.getEncoding()) {
962+
return emitError("src and result must have the same encoding");
963+
}
958964
if (getOffsets().size() != srcTy.getRank()) {
959965
return emitError("offsets must have the same rank as input");
960966
}
@@ -993,6 +999,9 @@ LogicalResult MemDescSubsliceOp::verify() {
993999
if (offset & (dstTy.getDimSize(dim) - 1)) {
9941000
return emitError("The split offset may not touch the tile");
9951001
}
1002+
if (offset >= srcTy.getDimSize(dim)) {
1003+
return emitError("The split offset may not exceed the source shape");
1004+
}
9961005
}
9971006
}
9981007

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,20 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
128128
<< "bitwidth * colStride must be less than or equal to 32. Got "
129129
<< bitwidth << " and " << enc.getColStride();
130130
}
131-
shape = shape.take_back(2);
131+
// Takes subslices into account and figures out whether we can construct
132+
// the linear layout at all
132133
allocShape = allocShape.take_back(2);
133134
auto ctaSplit = enc.getCGALayout().getCTASplitNum();
135+
auto blockN = std::min<int32_t>(enc.getBlockN(), shape.back());
134136
if (allocShape[0] < enc.getBlockM() * ctaSplit[0] ||
135-
allocShape[1] < enc.getBlockN() * ctaSplit[1]) {
137+
allocShape[1] < blockN * ctaSplit[1]) {
136138
return emitError() << "the allocation shape must be at least "
137139
<< enc.getBlockM() * ctaSplit[0] << "x"
138-
<< enc.getBlockN() * ctaSplit[1] << ". Got "
139-
<< allocShape;
140+
<< blockN * ctaSplit[1] << ". Got " << allocShape;
140141
}
142+
// Checks the layout of the allocation
141143
auto ll = toLinearLayout(allocShape, enc);
144+
// Sanity check that the layout is of the right shape
142145
auto dims = standardOutDimNames(ctx, 2);
143146
if (ll.getOutDimSize(dims[0]) != allocShape[0] ||
144147
ll.getOutDimSize(dims[1]) != allocShape[1]) {

0 commit comments

Comments (0)