Skip to content

Commit a303a03

Browse files
authored
[Consan] Support CLC (#10052)
CLC gets its own partition, running over threads 48-63. We model CLC as we model TMA writes, via a Barrier::EffectWrites. The idea of this mode is that we link all the writes on the op to the barrier. We also annotate in the table `barrierWriteRecipients` which CTAs will become visible once we wait on the associated barrier. We note something interesting and document it. `BarrierTrackingMode::Frontier` should be used when we have a commit/arrive/expect op that affects anything in flight before it. Instead, we use `BarrierTrackingMode::EffectWrites` when the PTX op accepts a barrier so the barrier just signals the completion of the op's particular write. The other point we add is a flag `bool diagonalEffectRecipientCTAs`. This differentiates the behaviour between TMA, where after waiting on the barrier you see all the writes from all the CTAs in the multicas group, vs. the diagonal version, as in CLC, where waiting on CTAi just makes the thread see the CTAi memory.
1 parent 0ee2ec2 commit a303a03

14 files changed

Lines changed: 241 additions & 47 deletions

File tree

include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,16 @@ class FunctionBuilder {
169169
int thread, Value pred, MemType memType,
170170
Operation *insertPoint, Value recipientCTAs);
171171
// trackBarrierWriteForBuffer: mark a specific buffer as tracked by a
172-
// barrier in the write-tracking table.
172+
// barrier in the write-tracking table. When diagonalEffectRecipientCTAs is
173+
// false, every signaled barrier row publishes the full effectRecipientCTAs
174+
// mask. When it is true, barrier row i publishes only bit i of that mask.
173175
void createTrackBarrierWriteForBufferCall(ImplicitLocOpBuilder &b, Value mbar,
174176
Value buf, uint32_t length,
175177
Value pred, MemType memType,
176178
Operation *insertPoint,
177179
Value barrierRecipientCTAs,
178-
Value effectRecipientCTAs);
180+
Value effectRecipientCTAs,
181+
bool diagonalEffectRecipientCTAs);
179182
// clearBarrierWriteTracking: clear all write tracking associated with the
180183
// given barrier row.
181184
void createClearBarrierWriteTrackingCall(ImplicitLocOpBuilder &b, Value mbar,

include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ Auxiliary state is kept in distributed tensors and global scratch memory, with t
99
### Thread model
1010

1111
- Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions).
12-
- Peer classes: +16 Tensor Core (TC) threads and +16 TMA threads to model lack of ordering with base threads.
13-
- Total logical threads: 48. Bitmasks are sized to the next power of two: 64.
12+
- Peer classes: +16 TMA threads, +16 Tensor Core (TC) threads, and +16 CLC threads to model lack of ordering with base threads.
13+
- Total logical threads: 64.
1414

15-
Indexing uses a logical thread id in [0, 48), with column vectors sized to 64 for layout convenience.
15+
Indexing uses a logical thread id in [0, 64), with column vectors sized to 64 for layout convenience.
1616

1717
## Auxiliary data structures
1818

@@ -21,7 +21,7 @@ All types are generated on-demand (per partition) based on:
2121
- B: number of tracked buffers (power-of-two padded)
2222
- K: number of mbarriers (power-of-two padded)
2323
- T_bits: 64 (bitmask width)
24-
- T_commits: 16 (base threads; commit counters do not apply to TC/TMA helpers)
24+
- T_commits: 16 (base threads; commit counters do not apply to TC/TMA/CLC helpers)
2525

2626
“tensor” means a distributed Triton tensor; “scratch” means a pointer into global scratch memory. Shapes below are logical; actual encodings are partition-local blocked layouts.
2727

@@ -53,7 +53,7 @@ ConSan separates “tracking” from “visibility transfer”:
5353
- experimental_set_read_visibility / experimental_set_write_visibility updates the appropriate visibility table for the current thread and buffer.
5454
- experimental_track_visible_reads / experimental_track_visible_writes snapshots current per-buffer visibility into readTracking/writeTracking for the given barrier.
5555
- At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes.
56-
- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC) to keep the three classes consistent.
56+
- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC, CLC) to keep the classes consistent.
5757

5858
### Barrier phase/count tracking
5959

include/triton/Dialect/TritonInstrument/IR/Utility.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "triton/Dialect/Triton/IR/Utility.h"
77
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
88
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
9+
#include "llvm/Support/MathExtras.h"
910

1011
#include <array>
1112

@@ -22,8 +23,11 @@ constexpr int numMemTypes = getMaxEnumValForMemType() + 1;
2223
constexpr int NUM_THREADS = 16;
2324
constexpr int TMA_THREAD_OFFSET = NUM_THREADS;
2425
constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS;
25-
constexpr int TOTAL_NUM_THREADS = TC_THREAD_OFFSET + NUM_THREADS;
26-
constexpr int THREADS_BITMASK_SIZE = llvm::NextPowerOf2(TOTAL_NUM_THREADS);
26+
constexpr int CLC_THREAD_OFFSET = TC_THREAD_OFFSET + NUM_THREADS;
27+
constexpr int TOTAL_NUM_THREADS = CLC_THREAD_OFFSET + NUM_THREADS;
28+
static_assert(TOTAL_NUM_THREADS <= 64,
29+
"ConSan thread bitsets are stored in i64 masks");
30+
const int THREADS_BITMASK_SIZE = llvm::PowerOf2Ceil(TOTAL_NUM_THREADS);
2731

2832
namespace CommitKind {
2933
enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds };

include/triton/Dialect/TritonInstrument/Transforms/ConSanTargetHooks.h

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@
1212
namespace mlir::triton::instrument {
1313

1414
struct MemEffectsOpInfo {
15-
// Frontier: snapshot thread-visible frontier into barrier tracking.
16-
// EffectWrites: track only buffers written by op effects.
15+
// Controls which memory effects become visible to a CTA after it waits on
16+
// this barrier.
17+
//
18+
// Frontier snapshots the issuing thread's current visibility frontier into
19+
// the barrier. A later wait publishes whatever shared/tensor memory writes
20+
// and reads were visible to that logical thread before the arrive/commit. Use
21+
// this for ordering operations whose semantics are a release of prior work.
22+
//
23+
// EffectWrites does not snapshot the whole thread frontier. Instead, it
24+
// attaches only the explicit write effects of this op to the barrier. A later
25+
// wait publishes those op-local writes and nothing else. Use this for PTX ops
26+
// that perform the write and also signal the barrier via
27+
// `mbarrier::complete_tx`.
1728
enum class BarrierTrackingMode {
1829
Frontier,
1930
EffectWrites,
@@ -34,6 +45,18 @@ struct MemEffectsOpInfo {
3445
int count;
3546
BarrierTrackingMode trackingMode = BarrierTrackingMode::Frontier;
3647
int txCount = 0;
48+
// For EffectWrites, effectRecipientCTAs identifies the CTA rows where the
49+
// op wrote its explicit result. By default, for
50+
// diagonalEffectRecipientCTAs=false, waiting on a barrier publishes the CTA
51+
// rows in effectRecipientCTAs, which is the full mask. This is the
52+
// behaviour of TMA multicast. If diagonalEffectRecipientCTAs is true,
53+
// waiting on a barrier publishes only the CTA rows in effectRecipientCTAs,
54+
// which is the diagonal mask. e.g. effectRecipientCTAs = 0b1101 If
55+
// DiagonalEffectRecipientCTAs is false, waiting on the barrier publishes
56+
// the following CTA rows: CTA0 0b1101 CTA1 0b1101 CTA2 0b1101 CTA3 0b1101
57+
// If diagonalEffectRecipientCTAs is true, waiting on the barrier publishes
58+
// the following CTA rows: CTA0 0b1000 CTA1 0b0100 CTA2 0b0000 CTA3 0b0001
59+
bool diagonalEffectRecipientCTAs = false;
3760
};
3861
enum class TrackingKind {
3962
None,
@@ -83,6 +106,8 @@ class ConSanTargetHooks {
83106

84107
virtual bool isTMAOp(Operation *op) const = 0;
85108

109+
virtual bool isCLCOp(Operation *op) const { return false; }
110+
86111
virtual std::optional<BarrierInitInfo>
87112
getBarrierInitInfo(Operation *op) const = 0;
88113

lib/Analysis/BufferRegion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ void BufferRegionAnalysis::calculateUsedBufferRegions(Operation *op) {
323323
bool BufferRegionAnalysis::isMemoryAccessOperation(Operation *op) {
324324
if (isa<ttg::LocalLoadOp, ttg::LocalStoreOp, ttng::TMEMLoadOp,
325325
ttng::TMEMStoreOp, ttng::TMEMCopyOp, ttg::AsyncCopyGlobalToLocalOp,
326-
ttng::TMAOpInterface>(op)) {
326+
ttng::TMAOpInterface, ttng::CLCLoadResultOp>(op)) {
327327
return true;
328328
}
329329
if (isa<ttg::MBarrierOpInterface>(op)) {

lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,13 +1674,13 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b,
16741674
void FunctionBuilder::createTrackBarrierWriteForBufferCall(
16751675
ImplicitLocOpBuilder &b, Value mbar, Value buf, uint32_t length, Value pred,
16761676
MemType memType, Operation *insertPoint, Value barrierRecipientCTAs,
1677-
Value effectRecipientCTAs) {
1677+
Value effectRecipientCTAs, bool diagonalEffectRecipientCTAs) {
16781678
if (auxData.barriers.empty() || auxData.buffers[(int)memType].empty() ||
16791679
auxData.writeTracking[(int)memType].empty()) {
16801680
return;
16811681
}
16821682
assert(!auxData.barrierWriteRecipients.empty() &&
1683-
"barrier write recipients must exist when tracking TMA writes");
1683+
"barrier write recipients must exist when tracking EffectWrites");
16841684
if (!pred)
16851685
pred = arith::ConstantIntOp::create(b, 1, 1);
16861686
Value barriersVal = auxData.barriers.at(insertPoint).value;
@@ -1717,10 +1717,10 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall(
17171717
b, "track_barrier_write_for_buffer", args,
17181718
/*assertInfo=*/std::nullopt,
17191719
{barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType,
1720-
(uint64_t)memType},
1721-
[barriersType, buffersType, writeTrackingType,
1722-
barrierWriteRecipientsType](ImplicitLocOpBuilder &fb,
1723-
Block *entryBlock) {
1720+
(uint64_t)memType, (uint64_t)diagonalEffectRecipientCTAs},
1721+
[barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType,
1722+
diagonalEffectRecipientCTAs](ImplicitLocOpBuilder &fb,
1723+
Block *entryBlock) {
17241724
Value mbarOffset = entryBlock->getArgument(0);
17251725
Value mbarLengthVal = entryBlock->getArgument(1);
17261726
Value pred = entryBlock->getArgument(2);
@@ -1747,6 +1747,30 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall(
17471747
createCmpIntTensorScalar(fb, barriers, barrierDescriptor);
17481748
Value effectRecipientCTAsTensor = triton::SplatOp::create(
17491749
fb, barrierWriteRecipientsType, effectRecipientCTAs);
1750+
if (diagonalEffectRecipientCTAs) {
1751+
// Expand the effect CTA mask diagonally: barrier row i publishes only
1752+
// bit i. This models per-CTA results, while the default replicated
1753+
// mask models TMA multicast where a barrier publishes all result
1754+
// rows.
1755+
auto encoding = cast<ttg::DistributedEncodingTrait>(
1756+
barrierWriteRecipientsType.getEncoding());
1757+
auto rowSliceEncoding =
1758+
tti::getSingleDimSliceEncoding(encoding, /*dim=*/0);
1759+
int numCTAs = barrierWriteRecipientsType.getShape()[0];
1760+
auto rowType = RankedTensorType::get({numCTAs}, fb.getI32Type(),
1761+
rowSliceEncoding);
1762+
Value rowIdx = triton::MakeRangeOp::create(fb, rowType,
1763+
/*start=*/0,
1764+
/*end=*/numCTAs);
1765+
auto indexType =
1766+
cast<RankedTensorType>(barrierWriteRecipientsType.cloneWith(
1767+
std::nullopt, fb.getI32Type()));
1768+
rowIdx = convertAndBroadcast(fb, rowIdx, {0}, indexType);
1769+
Value one = tti::createConstIntTensor(fb, fb.getLoc(), 1, indexType);
1770+
Value rowBit = arith::ShLIOp::create(fb, one, rowIdx);
1771+
effectRecipientCTAsTensor =
1772+
arith::AndIOp::create(fb, effectRecipientCTAsTensor, rowBit);
1773+
}
17501774
Value updatedBarrierWriteRecipients = arith::OrIOp::create(
17511775
fb, barrierWriteRecipients, effectRecipientCTAsTensor);
17521776
updatedBarrierWriteRecipients = arith::SelectOp::create(
@@ -2365,11 +2389,10 @@ void FunctionBuilder::createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b,
23652389
Value zeroTensor =
23662390
tti::createConstIntTensor(fb, fb.getLoc(), 0, writeVisibilityType);
23672391

2368-
constexpr uint64_t fullMask =
2369-
tti::THREADS_BITMASK_SIZE == 64
2370-
? std::numeric_limits<uint64_t>::max()
2371-
: (std::numeric_limits<uint64_t>::max() >>
2372-
(64 - tti::THREADS_BITMASK_SIZE));
2392+
uint64_t fullMask = tti::THREADS_BITMASK_SIZE == 64
2393+
? std::numeric_limits<uint64_t>::max()
2394+
: (std::numeric_limits<uint64_t>::max() >>
2395+
(64 - tti::THREADS_BITMASK_SIZE));
23732396
Value fullMaskVal = arith::ConstantIntOp::create(fb, fullMask, 64);
23742397
Value destMaskElem = adjustIntegerWidth(fb, destMaskVal, elemType);
23752398
Value fullMaskElem = adjustIntegerWidth(fb, fullMaskVal, elemType);

lib/Dialect/TritonInstrument/IR/Utility.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,7 @@ void AuxDataMap::populateAndPassToWarpSpecialize(
591591
if (numCTAs > 1) {
592592
ClusterBarrierOp::create(b, b.getLoc());
593593
} else {
594-
BarrierOp::create(b, b.getLoc(),
595-
AddrSpace::GlobalRead | AddrSpace::GlobalWrite);
594+
BarrierOp::create(b, b.getLoc(), AddrSpace::Local);
596595
}
597596
lock.insert(entryRegion, {lockVal, lockVal.getType()});
598597
passToWarpSpecialize(entryPoint, lock.at(entryRegion), lock, captureCounter);

lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
// buffers | tensor | <C x B x i64> | Base pointers of all (sub)buffers
2424
// barriers | tensor | <C x K x i64> | Pointers to all individual mbarriers
2525
// barrierStates | scratch | <C x K x i64> | Packed barrier phase (bit 0), arrival counts (bits[1..20] init, [21..40] current), and signed tx-count (bits[41..61]); zero means invalid/uninitialized
26-
// barrierWriteRecipients | scratch | <C x K x i32> | CTA bitsets of write-tracking rows reached by outstanding TMA effects on each barrier
26+
// barrierWriteRecipients | scratch | <C x K x i32> | CTA bitsets of EffectWrites rows published by each barrier
2727
// waiting | scratch | <C x K x i32> | Two bits per thread: waiting flag bit (LSB), stored phase bit (bit 1)
2828
// writeVisibility | scratch | <C x B x i64> | Per-buffer thread-visibility bitmask (bit i => thread i visible)
2929
// readVisibility | scratch | <C x B x T x i64> | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks)
@@ -120,12 +120,16 @@ int getCurrentThread(Operation *op, const ConSanTargetHooks *hooks) {
120120
thread += TC_THREAD_OFFSET;
121121
return thread;
122122
}
123+
if (hooks->isCLCOp(op)) {
124+
thread += CLC_THREAD_OFFSET;
125+
return thread;
126+
}
123127
return thread;
124128
}
125129

126130
int getBaseThread(int thread) { return thread % NUM_THREADS; }
127131

128-
// Peer threads are the equivalent threads in the TMA, TC and normal
132+
// Peer threads are the equivalent threads in the TMA, TC, CLC and normal
129133
// thread classes.
130134
// If a thread is a base thread, return the mask with the peers, otherwise
131135
// return the mask with the thread itself.
@@ -134,6 +138,7 @@ uint64_t getThreadPeersMask(int thread) {
134138
if (thread < NUM_THREADS) {
135139
mask |= 1ULL << (thread + TMA_THREAD_OFFSET);
136140
mask |= 1ULL << (thread + TC_THREAD_OFFSET);
141+
mask |= 1ULL << (thread + CLC_THREAD_OFFSET);
137142
}
138143
return mask;
139144
}
@@ -159,6 +164,12 @@ Value currentCTAMask(ImplicitLocOpBuilder &b) {
159164
ctaId);
160165
}
161166

167+
Value allCTAsMask(ImplicitLocOpBuilder &b) {
168+
int numCTAs = ttg::lookupNumCTAs(b);
169+
assert(numCTAs <= 16 && "ConSan CTA bitsets assume at most 16 CTAs");
170+
return arith::ConstantIntOp::create(b, (1u << numCTAs) - 1, 32);
171+
}
172+
162173
uint16_t getBlockBroadcastMask(Value alloc) {
163174
auto allocTy = cast<ttg::MemDescType>(alloc.getType());
164175
auto kBlock = StringAttr::get(alloc.getContext(), "block");
@@ -259,6 +270,8 @@ Value getMemEffectRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
259270
return getMulticastRecipientCTAs(b, tmaLoad.getResult());
260271
return currentCTAMask(b);
261272
}
273+
if (isa<ttng::CLCTryCancelOp>(op))
274+
return allCTAsMask(b);
262275
if (isTensorCoreOp(op))
263276
return getRecipientCTAsForBroadcastMasks(
264277
b, ttng::getCTABroadcastMasks(ttng::getModuleTwoCTAs(op), {}));
@@ -278,6 +291,8 @@ Value getBarrierRecipientCTAs(ImplicitLocOpBuilder &b, Operation *op) {
278291
tmaLoad.getBarrier());
279292
return getLeaderCTA(b, tmaLoad.getBarrier());
280293
}
294+
if (isa<ttng::CLCTryCancelOp>(op))
295+
return allCTAsMask(b);
281296

282297
if (isTensorCoreOp(op))
283298
return getRecipientCTAsForBroadcastMasks(
@@ -560,7 +575,8 @@ class ConcurrencySanitizerImpl {
560575
memType = MemType::SHARED_MEM;
561576
funcBuilder.createTrackBarrierWriteForBufferCall(
562577
b, barrier, effect.buf, effect.length, combinedPred, memType, op,
563-
recipientCTAs, effectRecipientCTAs);
578+
recipientCTAs, effectRecipientCTAs,
579+
barrierInfo.diagonalEffectRecipientCTAs);
564580
}
565581
}
566582
if (barrierInfo.count > 0 || barrierInfo.txCount != 0) {

lib/Dialect/TritonNvidiaGPU/Transforms/ClusterBarrierInsertion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ usesTrackedBarrierInCrossCTAConsumerOp(Operation *op,
108108
if (auto commit = dyn_cast<ttng::TCGen5CommitOp>(op)) {
109109
return ttng::getModuleTwoCTAs(op) && aliasesTracked(commit.getBarrier());
110110
}
111-
if (auto tma = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
111+
if (auto tma = dyn_cast<ttng::TMALoadLikeOpInterface>(op)) {
112112
return tma.getMulticast() && aliasesTracked(tma.getBarrier());
113113
}
114-
if (auto tma = dyn_cast<ttng::AsyncTMAGatherOp>(op)) {
115-
return tma.getMulticast() && aliasesTracked(tma.getBarrier());
114+
if (auto clc = dyn_cast<ttng::CLCTryCancelOp>(op)) {
115+
return aliasesTracked(clc.getMbarrier());
116116
}
117117
return false;
118118
}

lib/Dialect/TritonNvidiaGPU/Transforms/ConSanNVIDIA.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
4242
return isa<ttng::TMAOpInterface>(op);
4343
}
4444

45+
bool isCLCOp(Operation *op) const override {
46+
return isa<ttng::CLCTryCancelOp>(op);
47+
}
48+
4549
std::optional<BarrierInitInfo>
4650
getBarrierInitInfo(Operation *op) const override {
4751
if (auto initOp = dyn_cast<ttng::InitBarrierOp>(op)) {
@@ -101,6 +105,11 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
101105
}
102106
if (auto storeOp = dyn_cast<ttng::TMAStoreLikeOpInterface>(op))
103107
mask = getBlockBroadcastMask(storeOp.getSrc().getType());
108+
if (isa<ttng::CLCTryCancelOp>(op) && ttg::lookupNumCTAs(op) > 1) {
109+
Value ctaId = tti::ExperimentalClusterCTAIdOp::create(b, b.getLoc());
110+
return arith::CmpIOp::create(b, arith::CmpIPredicate::eq, ctaId,
111+
arith::ConstantIntOp::create(b, 0, 32));
112+
}
104113

105114
// In 2CTA tcgen05 and tmem_copy, only the even CTA in each (i, i^1) pair
106115
// issues the op.
@@ -226,6 +235,24 @@ class NVIDIAConSanHooks : public tti::ConSanTargetHooks {
226235
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write,
227236
loadOp.getResult());
228237
}
238+
if (auto tryCancelOp = dyn_cast<ttng::CLCTryCancelOp>(op)) {
239+
info.emplace();
240+
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
241+
info->barriers.push_back(
242+
{tryCancelOp.getMbarrier(), nullptr, /*count=*/0,
243+
MemEffectsOpInfo::BarrierTrackingMode::EffectWrites,
244+
/*txCount=*/
245+
-static_cast<int>(tti::getMemDescLength(tryCancelOp.getResult())),
246+
/*diagonalEffectRecipientCTAs=*/true});
247+
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Write,
248+
tryCancelOp.getResult());
249+
}
250+
if (auto loadResultOp = dyn_cast<ttng::CLCLoadResultOp>(op)) {
251+
info.emplace();
252+
info->trackingKind = MemEffectsOpInfo::TrackingKind::Barrier;
253+
info->operandEffects.emplace_back(MemEffectsOpInfo::Effects::Read,
254+
loadResultOp.getSrc());
255+
}
229256
if (auto storeOp = dyn_cast<ttng::TMAStoreLikeOpInterface>(op)) {
230257
info.emplace();
231258
info->trackingKind = MemEffectsOpInfo::TrackingKind::CommitCount;

0 commit comments

Comments
 (0)