Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
45e2e4b
Accelerate FPSan MMA emulation with i8 decomposition
jeffniu-openai Jun 3, 2026
a784db0
Test FPSan TCGen MMA in warp partitions
jeffniu-openai Jun 3, 2026
333348b
[FPSan] Address i8 decomposition review comments
jeffniu-openai Jun 4, 2026
4dd75e4
Support multi-CTA local gather and scatter
jeffniu-openai Jun 4, 2026
80bf7cc
Simplify multi-CTA gather and scatter lowering
jeffniu-openai Jun 4, 2026
1ffc0a1
Preserve explicit cluster gather codegen
jeffniu-openai Jun 4, 2026
6dc53ac
Apply pre-commit formatting
jeffniu-openai Jun 4, 2026
4869637
Add instrumentation local gather for FPSan
jeffniu-openai Jun 4, 2026
2c9405f
Simplify instrumentation local gather
jeffniu-openai Jun 4, 2026
b1d02d7
Apply pre-commit formatting
jeffniu-openai Jun 4, 2026
8cc70b2
Apply post-restack formatting
jeffniu-openai Jun 5, 2026
5194768
Merge remote-tracking branch 'refs/remotes/github/main' into jeffniu/…
jeffniu-openai Jun 5, 2026
abde891
[NVIDIA] Address multi-CTA gather review
jeffniu-openai Jun 5, 2026
68037d4
[NVIDIA] Minimize multi-CTA shared dispatch
jeffniu-openai Jun 5, 2026
3d0f65f
[NVIDIA] Trim multi-CTA gather changes
jeffniu-openai Jun 6, 2026
d098a0e
[NVIDIA] Restore multi-CTA lowering coverage
jeffniu-openai Jun 6, 2026
96e29f0
[NVIDIA] Simplify multi-CTA runtime test setup
jeffniu-openai Jun 6, 2026
1e55b3c
[NVIDIA] Use nullable values for distributed shared memory
jeffniu-openai Jun 6, 2026
d9bd4b7
merge
jeffniu-openai Jun 6, 2026
2396557
Merge branch 'jeffniu/local-gather-scatter-multicta' of https://githu…
jeffniu-openai Jun 6, 2026
ac28f5d
cleanup
jeffniu-openai Jun 6, 2026
ac6263d
[NVIDIA] Always map distributed shared accesses
jeffniu-openai Jun 6, 2026
32f914d
[GPUToLLVM] Lookup local address outputs by name
jeffniu-openai Jun 6, 2026
a54b8e9
[NVIDIA] Relax local gather barrier check
jeffniu-openai Jun 7, 2026
87bbf02
Merge branch 'jeffniu/local-gather-scatter-multicta' into jeffniu/tti…
jeffniu-openai Jun 7, 2026
967ee8f
Merge triton-lang/triton main into jeffniu/tti-experimental-local-gather
jeffniu-openai Jun 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,12 @@ struct LocalSharedMemoryAddress {
// Compute per-element shared-memory addresses for a local atomic/ldst update by
// replacing `coords[*][axis]` with `idxValues[*]`, adding the optional
// per-dimension `offsets` (when provided) to the resulting logical
// coordinates, and mapping those coordinates back to shared-memory offsets and
// target CTAs.
SmallVector<LocalSharedMemoryAddress> computeLocalAddrs(
Location loc, triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy, ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords, unsigned axis, RewriterBase &rewriter);
SmallVector<LocalSharedMemoryAddress>
computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords, unsigned axis,
RewriterBase &rewriter, ArrayRef<Value> offsets = {});

// Backend-agnostic preparation for lowering LocalAtomicScatterRMWOp.
struct LocalAtomicScatterRMWInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td"
// Interfaces
//
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;

//
// Ops
Expand Down Expand Up @@ -78,6 +79,29 @@ def TTI_ExperimentalClusterCTAIdOp
let assemblyFormat = "attr-dict `:` type($result)";
}

def TTI_ExperimentalLocalGatherOp

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we just slice and local_load?

Comment thread
peterbell10 marked this conversation as resolved.
: TTI_Op<"experimental_local_gather"> {
let summary = "Gather elements from shared memory with logical base offsets";
let description = [{
Gather elements from a shared memory descriptor using an index tensor along
one axis, after shifting the logical source coordinates by rank-sized scalar
offsets. This is intentionally private to instrumentation passes.
}];
let arguments = (ins
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src,
TT_IntTensor:$indices,
Variadic<I32>:$offsets,
I32Attr:$axis
);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
$src `[` $indices `]` `offsets` `=` `[` $offsets `]`
attr-dict `:` qualified(type($src)) `,` type($indices) `->` type($result)
}];
let hasVerifier = 1;
}

def TTI_ExperimentalGSanInitOp
: TTI_Op<"experimental_gsan_init"> {
let summary = "Initialize GSan thread";
Expand Down
7 changes: 5 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords, unsigned axis,
RewriterBase &rewriter) {
RewriterBase &rewriter, ArrayRef<Value> offsets) {
MLIRContext *ctx = memDescTy.getContext();
auto b = TritonLLVMOpBuilder(loc, rewriter);

Expand Down Expand Up @@ -580,9 +580,12 @@ computeLocalAddrs(Location loc, triton::gpu::MemDescType memDescTy,
idx = b.zext(i32_ty, idx);
}

// Copy coordinates and replace the axis coordinate with the index value
// Copy coordinates, replace the axis coordinate with the index value, and
// then shift all logical coordinates by the optional base offsets.
SmallVector<Value> indices(coords[i]);
indices[axis] = idx;
for (auto [dim, offset] : llvm::enumerate(offsets))
indices[dim] = b.add(indices[dim], offset);

// Apply inverted shared layout to compute offset
SmallVector<std::pair<StringAttr, Value>> inputs;
Expand Down
46 changes: 46 additions & 0 deletions lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,51 @@ struct ClusterCTAIdOpConversion
const TargetInfoBase &targetInfo;
};

// Conversion pattern for tti::ExperimentalLocalGatherOp: computes one
// shared-memory address per result element and loads each element through the
// target's loadDShared hook.
struct LocalGatherOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalLocalGatherOp> {
  LocalGatherOpConversion(const LLVMTypeConverter &converter,
                          const TargetInfoBase &targetInfo,
                          PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<tti::ExperimentalLocalGatherOp>(converter,
                                                               benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(tti::ExperimentalLocalGatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto srcDescTy = cast<ttg::MemDescType>(op.getSrc().getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    auto converter = getTypeConverter();

    // The loaded elements use the converted element type of the source
    // descriptor.
    Type elemTy = converter->convertType(srcDescTy.getElementType());
    auto sharedObj = LLVM::getSharedMemoryObjectFromStruct(
        loc, adaptor.getSrc(), elemTy, rewriter);

    // Per-element gather indices and the coordinates of every result element
    // (emitted with the CTA offset included).
    auto gatherIdx = unpackLLElements(loc, adaptor.getIndices(), rewriter);
    auto resultCoords =
        emitIndices(loc, rewriter, targetInfo, resultTy.getEncoding(), resultTy,
                    /*withCTAOffset=*/true);
    SmallVector<Value> baseOffsets(adaptor.getOffsets());

    // Substitute the gather index along `axis`, apply the base offsets, and
    // resolve each coordinate to a shared-memory pointer and CTA id.
    auto elemAddrs =
        computeLocalAddrs(loc, srcDescTy, sharedObj, elemTy, gatherIdx,
                          resultCoords, op.getAxis(), rewriter, baseOffsets);

    // Issue one load per address; the predicate passed to each load is a
    // constant true value.
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    SmallVector<Value> loaded;
    loaded.reserve(elemAddrs.size());
    for (const LocalSharedMemoryAddress &addr : elemAddrs) {
      Value alwaysTrue = b.true_val();
      loaded.push_back(targetInfo.loadDShared(rewriter, loc, addr.ptr,
                                              addr.ctaId, elemTy, alwaysTrue));
    }

    Value packed = packLLElements(loc, converter, loaded, rewriter, resultTy);
    rewriter.replaceOp(op, packed);
    return success();
  }

private:
  const TargetInfoBase &targetInfo;
};

} // namespace

void mlir::triton::populateInstrumentationToLLVMPatterns(
Expand All @@ -358,4 +403,5 @@ void mlir::triton::populateInstrumentationToLLVMPatterns(
patterns.add<LockReleaseOpConversion>(typeConverter, targetInfo);
patterns.add<MemDescToI32OpConversion>(typeConverter);
patterns.add<ClusterCTAIdOpConversion>(typeConverter, targetInfo);
patterns.add<LocalGatherOpConversion>(typeConverter, targetInfo);
}
34 changes: 34 additions & 0 deletions lib/Dialect/TritonInstrument/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,40 @@ LogicalResult DotI8Op::verify() {
bEnc);
}

// Verifies ExperimentalLocalGatherOp. The checks run in a fixed order and the
// first failing one determines the diagnostic:
//   1. the source has a shared-memory encoding,
//   2. the indices have an integer element type,
//   3. the result shape equals the indices shape,
//   4. the source and indices have the same rank,
//   5. `axis` is within the source rank,
//   6. the result element type equals the source element type,
//   7. the indices and result share the same layout,
//   8. there is exactly one offset per source dimension.
LogicalResult ExperimentalLocalGatherOp::verify() {
  auto memDescTy = getSrc().getType();
  auto idxTy = cast<RankedTensorType>(getIndices().getType());
  auto resultTy = cast<RankedTensorType>(getType());
  unsigned gatherAxis = getAxis();
  const auto srcRank = memDescTy.getRank();

  const bool srcIsShared =
      isa<ttg::SharedEncodingTrait>(memDescTy.getEncoding());
  if (!srcIsShared)
    return emitError("source must have shared memory encoding");

  const bool idxIsInteger = idxTy.getElementType().isInteger();
  if (!idxIsInteger)
    return emitError("indices must have integer element type");

  if (resultTy.getShape() != idxTy.getShape())
    return emitError("result shape must match indices shape");

  if (srcRank != idxTy.getRank())
    return emitError("source and indices must have the same rank");

  if (gatherAxis >= srcRank) {
    return emitError("axis ")
           << gatherAxis << " is out of bounds for source rank " << srcRank;
  }

  if (memDescTy.getElementType() != resultTy.getElementType())
    return emitError("result element type must match source element type");

  if (idxTy.getEncoding() != resultTy.getEncoding())
    return emitError("indices and result must have the same layout");

  const auto numOffsets = static_cast<int64_t>(getOffsets().size());
  if (numOffsets != srcRank)
    return emitError("offset count must match source rank");

  return success();
}

template <typename ViewOp, typename FPSanOp>
struct PushFPSanThroughViewPattern : public OpRewritePattern<ViewOp> {
using OpRewritePattern<ViewOp>::OpRewritePattern;
Expand Down
Loading
Loading