Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
45e2e4b
Accelerate FPSan MMA emulation with i8 decomposition
jeffniu-openai Jun 3, 2026
a784db0
Test FPSan TCGen MMA in warp partitions
jeffniu-openai Jun 3, 2026
333348b
[FPSan] Address i8 decomposition review comments
jeffniu-openai Jun 4, 2026
4dd75e4
Support multi-CTA local gather and scatter
jeffniu-openai Jun 4, 2026
80bf7cc
Simplify multi-CTA gather and scatter lowering
jeffniu-openai Jun 4, 2026
1ffc0a1
Preserve explicit cluster gather codegen
jeffniu-openai Jun 4, 2026
6dc53ac
Apply pre-commit formatting
jeffniu-openai Jun 4, 2026
4869637
Add instrumentation local gather for FPSan
jeffniu-openai Jun 4, 2026
2c9405f
Simplify instrumentation local gather
jeffniu-openai Jun 4, 2026
b1d02d7
Apply pre-commit formatting
jeffniu-openai Jun 4, 2026
8cc70b2
Apply post-restack formatting
jeffniu-openai Jun 5, 2026
5194768
Merge remote-tracking branch 'refs/remotes/github/main' into jeffniu/…
jeffniu-openai Jun 5, 2026
abde891
[NVIDIA] Address multi-CTA gather review
jeffniu-openai Jun 5, 2026
68037d4
[NVIDIA] Minimize multi-CTA shared dispatch
jeffniu-openai Jun 5, 2026
3d0f65f
[NVIDIA] Trim multi-CTA gather changes
jeffniu-openai Jun 6, 2026
d098a0e
[NVIDIA] Restore multi-CTA lowering coverage
jeffniu-openai Jun 6, 2026
96e29f0
[NVIDIA] Simplify multi-CTA runtime test setup
jeffniu-openai Jun 6, 2026
1e55b3c
[NVIDIA] Use nullable values for distributed shared memory
jeffniu-openai Jun 6, 2026
d9bd4b7
merge
jeffniu-openai Jun 6, 2026
2396557
Merge branch 'jeffniu/local-gather-scatter-multicta' of https://githu…
jeffniu-openai Jun 6, 2026
ac28f5d
cleanup
jeffniu-openai Jun 6, 2026
ac6263d
[NVIDIA] Always map distributed shared accesses
jeffniu-openai Jun 6, 2026
32f914d
[GPUToLLVM] Lookup local address outputs by name
jeffniu-openai Jun 6, 2026
a54b8e9
[NVIDIA] Relax local gather barrier check
jeffniu-openai Jun 7, 2026
87bbf02
Merge branch 'jeffniu/local-gather-scatter-multicta' into jeffniu/tti…
jeffniu-openai Jun 7, 2026
967ee8f
Merge triton-lang/triton main into jeffniu/tti-experimental-local-gather
jeffniu-openai Jun 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ bool supportWMMA(triton::DotOp op);

bool supportMMA(triton::DotOp op, int version);

bool supportMMA(triton::DotOpInterface op, int version);

bool supportMMA(Value value, int version);

// Conversion from `srcTy` to `dstTy` involving the minimum amount of data
Expand Down
21 changes: 19 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -605,16 +605,33 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
const LinearLayout &layout, RankedTensorType type,
bool withCTAOffset);

// Compute per-element shared-memory pointers for a local atomic/ldst update by
struct LocalSharedMemoryAddress {
Value ptr;
std::optional<Value> ctaId;
};

// Compute per-element shared-memory addresses for a local atomic/ldst update by
// replacing `coords[*][axis]` with `idxValues[*]` and mapping the resulting
// logical coordinates back to shared-memory offsets.
// logical coordinates back to shared-memory offsets and target CTAs.
SmallVector<LocalSharedMemoryAddress>
computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords, unsigned axis,
RewriterBase &rewriter, ArrayRef<Value> offsets = {});

SmallVector<Value> computeLocalPtrs(Location loc,
triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords,
unsigned axis, RewriterBase &rewriter);

SmallVector<Value> loadLocalAddrs(Location loc, Type llvmElemTy,
ArrayRef<LocalSharedMemoryAddress> addrs,
RewriterBase &rewriter,
const TargetInfoBase &targetInfo);

// Backend-agnostic preparation for lowering LocalAtomicScatterRMWOp.
struct LocalAtomicScatterRMWInfo {
RankedTensorType valuesTy;
Expand Down
8 changes: 7 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
/*desc=*/"Verify the dimensions of the A and B DotOp operands.",
/*retType=*/"bool",
/*methodName=*/"verifyDims",
/*args=*/(ins)>,
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImpl=*/ [{
auto aShape = cast<ShapedType>($_op.getA().getType()).getShape();
auto bShape = cast<ShapedType>($_op.getB().getType()).getShape();
return aShape.back() == bShape[bShape.size() - 2];
}]>,
InterfaceMethod<
/*desc=*/"Verify the dimensions of the DotOp output.",
/*retType=*/"bool",
Expand Down
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,12 @@ SmallVector<int64_t> getAllocationShapePerCTA(Type type);

unsigned getNumCTAs(Attribute layout);

// Returns the MMAv2 warp distribution for a matrix tile. This does not apply
// dot-chain policy and may oversubscribe tiles with fewer instruction
// repetitions than warps.
SmallVector<unsigned> getMmaV2WarpsPerCTA(ArrayRef<int64_t> shape,
int numWarps);

// Return the order that represents that the batch is in row-major or
// column-major order for a batch of matrices of shape [*, m, n] with
// len(shape) == rank.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand All @@ -14,6 +15,7 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td"
// Interfaces
//
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;

//
// Ops
Expand Down Expand Up @@ -77,6 +79,29 @@ def TTI_ExperimentalClusterCTAIdOp
let assemblyFormat = "attr-dict `:` type($result)";
}

def TTI_ExperimentalLocalGatherOp

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't just slice and local_load?

Comment thread
peterbell10 marked this conversation as resolved.
: TTI_Op<"experimental_local_gather"> {
let summary = "Gather elements from shared memory with logical base offsets";
let description = [{
Gather elements from a shared memory descriptor using an index tensor along
one axis, after shifting the logical source coordinates by rank-sized scalar
offsets. This is intentionally private to instrumentation passes.
}];
let arguments = (ins
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src,
TT_IntTensor:$indices,
Variadic<I32>:$offsets,
I32Attr:$axis
);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
$src `[` $indices `]` `offsets` `=` `[` $offsets `]`
attr-dict `:` qualified(type($src)) `,` type($indices) `->` type($result)
}];
let hasVerifier = 1;
}

def TTI_ExperimentalGSanInitOp
: TTI_Op<"experimental_gsan_init"> {
let summary = "Initialize GSan thread";
Expand Down Expand Up @@ -210,6 +235,32 @@ def TTI_ExperimentalLockReleaseOp : TTI_Op<"experimental_lock_release", [MemoryE
// ===== FPSan ops =====


def TTI_DotI8Op : TTI_Op<"dot_i8", [
Pure,
DeclareOpInterfaceMethods<DotOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">
]> {
let summary = "non-saturating NVIDIA MMAv2 i8 dot";
let description = [{
Performs a wrapping i8 matrix multiplication into an i32 accumulator using
NVIDIA MMAv2. The A and B operands have independent signedness.
}];
let arguments = (ins
RankedTensorOf<[I8]>:$a,
RankedTensorOf<[I8]>:$b,
RankedTensorOf<[I32]>:$c,
BoolAttr:$aSigned,
BoolAttr:$bSigned
);
let results = (outs RankedTensorOf<[I32]>:$d);
let assemblyFormat = [{
$a `,` $b `,` $c `,` `aSigned` `=` $aSigned `,` `bSigned` `=` $bSigned
attr-dict `:` type($a) `*` type($b) `->` type($d)
}];
let hasVerifier = 1;
}

def TTI_ExperimentalFPSanEmbedOp : TTI_Op<"experimental_fpsan_embed", [
Pure,
Elementwise,
Expand Down
6 changes: 6 additions & 0 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,12 @@ bool supportMMA(triton::DotOp op, int version) {
return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}

bool supportMMA(triton::DotOpInterface op, int version) {
if (auto dotOp = dyn_cast<triton::DotOp>(op.getOperation()))
return supportMMA(dotOp, version);
return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}

bool supportMMA(Value value, int version) {
// Tell whether a DotOp support MMA by the operand type(either $a or $b).
// We cannot get both the operand types(in TypeConverter), here we assume the
Expand Down
40 changes: 16 additions & 24 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,26 @@ lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy,
unsigned axis, ArrayRef<Value> storeVals, RewriterBase &rewriter,
const TargetInfoBase &targetInfo) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
bool isScatter = !storeVals.empty();
SmallVector<Value> ptrs = computeLocalPtrs(
SmallVector<LocalSharedMemoryAddress> addrs = computeLocalAddrs(
loc, memDescTy, smemObj, llvmElemTy, idxValues, coords, axis, rewriter);
if (storeVals.empty())
return loadLocalAddrs(loc, llvmElemTy, addrs, rewriter, targetInfo);

SmallVector<Value> results;
if (!isScatter)
results.resize(coords.size());
Value currentCtaId;
if (!addrs.empty() && addrs.front().ctaId)
currentCtaId = targetInfo.getClusterCTAId(rewriter, loc);

for (auto [i, ptr] : llvm::enumerate(ptrs)) {
if (isScatter) {
targetInfo.storeShared(rewriter, loc, ptr, storeVals[i], b.true_val());
SmallVector<Value> results;
for (auto [i, addr] : llvm::enumerate(addrs)) {
if (addr.ctaId) {
Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId);
Value isRemote = b.icmp_ne(*addr.ctaId, currentCtaId);
targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i], isLocal);
targetInfo.storeDShared(rewriter, loc, addr.ptr, addr.ctaId, storeVals[i],
isRemote);
} else {
results[i] =
targetInfo.loadShared(rewriter, loc, ptr, llvmElemTy, b.true_val());
targetInfo.storeShared(rewriter, loc, addr.ptr, storeVals[i],
b.true_val());
}
}

Expand Down Expand Up @@ -267,13 +273,6 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern<LocalGatherOp> {
auto loc = op.getLoc();
auto *ctx = op.getContext();
auto memDescTy = cast<MemDescType>(op.getSrc().getType());
// TODO: PartitionedSharedEncoding lowering will be enabled in subsequent
// PRs.
if (isa<triton::gpu::PartitionedSharedEncodingAttr>(
memDescTy.getEncoding())) {
return rewriter.notifyMatchFailure(
op, "PartitionedSharedEncoding not yet supported in lowering");
}
auto regTy = cast<RankedTensorType>(op.getType());
auto typeConverter = getTypeConverter();

Expand Down Expand Up @@ -316,13 +315,6 @@ struct LocalScatterOpConversion
auto loc = op.getLoc();
auto *ctx = op.getContext();
auto memDescTy = cast<MemDescType>(op.getDst().getType());
// TODO: PartitionedSharedEncoding lowering will be enabled in subsequent
// PRs.
if (isa<triton::gpu::PartitionedSharedEncodingAttr>(
memDescTy.getEncoding())) {
return rewriter.notifyMatchFailure(
op, "PartitionedSharedEncoding not yet supported in lowering");
}
auto valuesTy = cast<RankedTensorType>(op.getValues().getType());
auto typeConverter = getTypeConverter();

Expand Down
76 changes: 61 additions & 15 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,12 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
return emitIndices(loc, rewriter, target, ll, type, withCTAOffset);
}

SmallVector<Value> computeLocalPtrs(Location loc,
triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords,
unsigned axis, RewriterBase &rewriter) {
SmallVector<LocalSharedMemoryAddress>
computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords, unsigned axis,
RewriterBase &rewriter, ArrayRef<Value> offsets) {
MLIRContext *ctx = memDescTy.getContext();
auto b = TritonLLVMOpBuilder(loc, rewriter);

Expand All @@ -561,12 +561,15 @@ SmallVector<Value> computeLocalPtrs(Location loc,
allDims.push_back(str_attr("dim" + Twine(dim)));

auto kOffset = str_attr("offset");
auto kBlock = str_attr("block");
bool useBlockId = invSharedLayout.hasOutDim(kBlock) &&
invSharedLayout.getOutDimSize(kBlock) > 1;
// Get the subslice affine offset (non-zero for memdesc subslices)
Value affineOffset = smemObj.getShmemOffset(loc, rewriter, memDescTy);
auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy);

SmallVector<Value> ptrs;
ptrs.reserve(coords.size());
SmallVector<LocalSharedMemoryAddress> addrs;
addrs.reserve(coords.size());

for (auto [i, idxVal] : llvm::enumerate(idxValues)) {
Value idx = idxVal;
Expand All @@ -578,9 +581,12 @@ SmallVector<Value> computeLocalPtrs(Location loc,
idx = b.zext(i32_ty, idx);
}

// Copy coordinates and replace the axis coordinate with the index value
// Copy coordinates, replace the axis coordinate with the index value, and
// then shift all logical coordinates by the optional base offsets.
SmallVector<Value> indices(coords[i]);
indices[axis] = idx;
for (auto [dim, offset] : llvm::enumerate(offsets))
indices[dim] = b.add(indices[dim], offset);

// Apply inverted shared layout to compute offset
SmallVector<std::pair<StringAttr, Value>> inputs;
Expand All @@ -589,15 +595,18 @@ SmallVector<Value> computeLocalPtrs(Location loc,

auto outputs = applyLinearLayout(loc, rewriter, invSharedLayout, inputs);

// Extract the offset value
// Extract the offset and target CTA.
Value offset = nullptr;
Value blockId = nullptr;
for (auto [name, value] : outputs) {
if (name == kOffset) {
if (name == kOffset)
offset = value;
break;
}
else if (name == kBlock)
blockId = value;
}
assert(offset && "expected offset output from inverted shared layout");
assert((!useBlockId || blockId) &&
"expected block output from multi-CTA shared layout");

// For subslices, the physical offset is computed as:
// physical_offset = L⁻¹(coords) ⊕ L⁻¹(subslice_logical_offset)
Expand Down Expand Up @@ -626,10 +635,47 @@ SmallVector<Value> computeLocalPtrs(Location loc,
ptr = b.gep(smemObj.getBase().getType(), llvmElemTy, smemObj.getBase(),
offset);
}
ptrs.push_back(ptr);
addrs.push_back(
{ptr, useBlockId ? std::optional<Value>(blockId) : std::nullopt});
}

return ptrs;
return addrs;
}

SmallVector<Value> computeLocalPtrs(Location loc,
triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords,
unsigned axis, RewriterBase &rewriter) {
return llvm::map_to_vector(
computeLocalAddrs(loc, memDescTy, smemObj, llvmElemTy, idxValues, coords,
axis, rewriter),
[](const LocalSharedMemoryAddress &addr) { return addr.ptr; });
}

SmallVector<Value> loadLocalAddrs(Location loc, Type llvmElemTy,
ArrayRef<LocalSharedMemoryAddress> addrs,
RewriterBase &rewriter,
const TargetInfoBase &targetInfo) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value currentCtaId;
if (!addrs.empty() && addrs.front().ctaId)
currentCtaId = targetInfo.getClusterCTAId(rewriter, loc);

return llvm::map_to_vector(
addrs, [&](const LocalSharedMemoryAddress &addr) -> Value {
if (!addr.ctaId)
return targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy,
b.true_val());
Value isLocal = b.icmp_eq(*addr.ctaId, currentCtaId);
Value local =
targetInfo.loadShared(rewriter, loc, addr.ptr, llvmElemTy, isLocal);
Value remote = targetInfo.loadDShared(
rewriter, loc, addr.ptr, addr.ctaId, llvmElemTy,
b.icmp_ne(*addr.ctaId, currentCtaId));
return b.select(isLocal, local, remote);
});
}

FailureOr<LocalAtomicScatterRMWInfo> prepareLocalAtomicScatterRMW(
Expand Down
Loading
Loading