Skip to content

Commit 89d154d

Browse files
authored
[Canonicalize] Transform ptr_to_int->add->int_to_ptr to addptr (#9971)
Add canonicalization pattern for IntToPtrOp that recognizes the pattern: int_to_ptr(addi(ptr_to_int(ptr), constant_offset)) and transforms it to: addptr(ptr, element_offset) where element_offset = constant_offset / element_size_bytes. This pattern appears when performing pointer arithmetic via integer operations (e.g., adding byte offsets to pointers). By canonicalizing to addptr, AxisInfoAnalysis can correctly track contiguity, enabling proper vectorization for operations like async_copy_local_to_global. The pattern only applies when: - The offset is a compile-time constant (IntegerAttr or SplatElementsAttr) - The byte offset is evenly divisible by the element size Added to both standard canonicalize and gluon-canonicalize passes. Tests added for positive cases (f32, f16, commutative) and negative cases (non-constant offset, indivisible offset).
1 parent 486d972 commit 89d154d

4 files changed

Lines changed: 183 additions & 0 deletions

File tree

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
4848
let results = (outs TT_PtrLike:$result);
4949

5050
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
51+
52+
let hasCanonicalizer = 1;
5153
}
5254

5355
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,

lib/Dialect/Gluon/Transforms/Canonicalize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ void Canonicalize::runOnOperation() {
5757
StoreOp::getCanonicalizationPatterns(patterns, ctx);
5858
BroadcastOp::getCanonicalizationPatterns(patterns, ctx);
5959
ExpandDimsOp::getCanonicalizationPatterns(patterns, ctx);
60+
IntToPtrOp::getCanonicalizationPatterns(patterns, ctx);
6061
ttg::WarpSpecializeOp::getCanonicalizationPatterns(patterns, ctx);
6162
ttg::WarpSpecializePartitionsOp::getCanonicalizationPatterns(patterns, ctx);
6263

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,113 @@ void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
10951095
return build(builder, state, descTy, base, shape, strides, paddingAttr);
10961096
}
10971097

1098+
//-- IntToPtrOp --
1099+
// Pattern 1: int_to_ptr(ptr_to_int(ptr)) -> ptr
1100+
// Eliminates round-trip pointer conversions
1101+
struct CanonicalizeIntToPtrOfPtrToInt : public OpRewritePattern<IntToPtrOp> {
1102+
CanonicalizeIntToPtrOfPtrToInt(MLIRContext *context)
1103+
: OpRewritePattern<IntToPtrOp>(context, 1) {}
1104+
1105+
LogicalResult matchAndRewrite(IntToPtrOp intToPtrOp,
1106+
PatternRewriter &rewriter) const override {
1107+
// Match: int_to_ptr(ptr_to_int(ptr))
1108+
auto ptrToIntOp = intToPtrOp.getSrc().getDefiningOp<PtrToIntOp>();
1109+
if (!ptrToIntOp)
1110+
return failure();
1111+
1112+
// Replace with the original pointer
1113+
rewriter.replaceOp(intToPtrOp, ptrToIntOp.getSrc());
1114+
return success();
1115+
}
1116+
};
1117+
1118+
// Pattern 2: int_to_ptr(addi(val, constant_offset)) -> addptr(int_to_ptr(val),
1119+
// element_offset). Only when offset is constant and divisible by element size
1120+
struct CanonicalizeIntToPtrWithAdd : public OpRewritePattern<IntToPtrOp> {
1121+
CanonicalizeIntToPtrWithAdd(MLIRContext *context)
1122+
: OpRewritePattern<IntToPtrOp>(context, 1) {}
1123+
1124+
LogicalResult matchAndRewrite(IntToPtrOp intToPtrOp,
1125+
PatternRewriter &rewriter) const override {
1126+
// Match: int_to_ptr(addi(val, constant_offset))
1127+
auto addOp = intToPtrOp.getSrc().getDefiningOp<arith::AddIOp>();
1128+
if (!addOp)
1129+
return failure();
1130+
1131+
Value intValue = addOp.getLhs();
1132+
Value offsetValue = addOp.getRhs();
1133+
1134+
// Get the element size from the result pointer type
1135+
auto resultType = intToPtrOp.getType();
1136+
auto ptrType = cast<PointerType>(getElementTypeOrSelf(resultType));
1137+
int64_t elemSizeBits = triton::getPointeeBitWidth(ptrType);
1138+
int64_t elemSizeBytes = std::max<int64_t>(1, elemSizeBits / 8);
1139+
1140+
// Check if offset is a constant (either directly or via splat)
1141+
// Only apply canonicalization for constant offsets
1142+
std::optional<int64_t> constantByteOffset;
1143+
if (auto constOp = offsetValue.getDefiningOp<arith::ConstantOp>()) {
1144+
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
1145+
constantByteOffset = intAttr.getValue().getSExtValue();
1146+
} else if (auto splatAttr =
1147+
dyn_cast<SplatElementsAttr>(constOp.getValue())) {
1148+
constantByteOffset =
1149+
splatAttr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
1150+
}
1151+
}
1152+
1153+
if (!constantByteOffset.has_value())
1154+
return failure(); // Only handle constant offsets
1155+
1156+
// Check if the byte offset is divisible by element size
1157+
if (constantByteOffset.value() % elemSizeBytes != 0)
1158+
return failure();
1159+
1160+
// Compute element offset at compile time
1161+
int64_t elementOffset = constantByteOffset.value() / elemSizeBytes;
1162+
1163+
// Create int_to_ptr(val) for the base
1164+
auto loc = intToPtrOp.getLoc();
1165+
Value basePtr = IntToPtrOp::create(rewriter, loc, resultType, intValue);
1166+
1167+
// Create the element offset constant
1168+
Value elementOffsetValue;
1169+
1170+
// Get the integer type from the offset value to match its type
1171+
Type offsetElemType;
1172+
if (auto tensorType = dyn_cast<RankedTensorType>(offsetValue.getType())) {
1173+
offsetElemType = tensorType.getElementType();
1174+
} else {
1175+
offsetElemType = offsetValue.getType();
1176+
}
1177+
1178+
if (auto tensorType = dyn_cast<RankedTensorType>(resultType)) {
1179+
// Create a splat constant for tensor types, matching the offset's type
1180+
auto offsetAttr = rewriter.getIntegerAttr(offsetElemType, elementOffset);
1181+
auto splatType = RankedTensorType::get(
1182+
tensorType.getShape(), offsetElemType, tensorType.getEncoding());
1183+
auto splatAttr = SplatElementsAttr::get(splatType, offsetAttr);
1184+
elementOffsetValue = arith::ConstantOp::create(rewriter, loc, splatAttr);
1185+
} else {
1186+
// Scalar case
1187+
elementOffsetValue = arith::ConstantOp::create(
1188+
rewriter, loc,
1189+
rewriter.getIntegerAttr(offsetElemType, elementOffset));
1190+
}
1191+
1192+
// Replace with addptr(int_to_ptr(val), element_offset)
1193+
rewriter.replaceOpWithNewOp<AddPtrOp>(intToPtrOp, resultType, basePtr,
1194+
elementOffsetValue);
1195+
return success();
1196+
}
1197+
};
1198+
1199+
void IntToPtrOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Register both int_to_ptr canonicalizations: round-trip elimination and
  // rewriting of constant integer offsets into addptr.
  results.add<CanonicalizeIntToPtrOfPtrToInt>(context);
  results.add<CanonicalizeIntToPtrWithAdd>(context);
}
1204+
10981205
// The following ops, including `call`, `func`, and `return` are copied and
10991206
// modified from
11001207
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp

test/Triton/canonicalize.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,76 @@ tt.func @fold_transpose_constant() -> tensor<128x16xf32> {
173173
// CHECK-NEXT: tt.return %[[cst]] : tensor<128x16xf32>
174174
tt.return %r : tensor<128x16xf32>
175175
}
176+
// -----
177+
178+
// CHECK-LABEL: @canonicalize_int_to_ptr_of_ptr_to_int
// The ptr -> int -> ptr round trip must fold away entirely.
tt.func @canonicalize_int_to_ptr_of_ptr_to_int(%ptr: tensor<64x!tt.ptr<f32>>) -> tensor<64x!tt.ptr<f32>> {
  // CHECK-NOT: tt.ptr_to_int
  // CHECK-NOT: tt.int_to_ptr
  // CHECK: tt.return %{{.*}} : tensor<64x!tt.ptr<f32>>
  %as_int = tt.ptr_to_int %ptr : tensor<64x!tt.ptr<f32>> -> tensor<64xi64>
  %round_trip = tt.int_to_ptr %as_int : tensor<64xi64> -> tensor<64x!tt.ptr<f32>>
  tt.return %round_trip : tensor<64x!tt.ptr<f32>>
}
188+
189+
// -----
190+
191+
// CHECK-LABEL: @canonicalize_int_to_ptr_with_constant_offset_f32
// int_to_ptr(addi(ptr_to_int(ptr), constant)) becomes addptr(ptr, element_offset);
// a 16-byte offset on f32 (4 bytes each) is 4 elements.
tt.func @canonicalize_int_to_ptr_with_constant_offset_f32(%base: tensor<128x!tt.ptr<f32>>) -> tensor<128x!tt.ptr<f32>> {
  // CHECK: %[[OFFSET:.*]] = arith.constant dense<4> : tensor<128xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = tt.addptr %{{.*}}, %[[OFFSET]] : tensor<128x!tt.ptr<f32>>, tensor<128xi64>
  %bytes = arith.constant dense<16> : tensor<128xi64>
  %base_int = tt.ptr_to_int %base : tensor<128x!tt.ptr<f32>> -> tensor<128xi64>
  %shifted = arith.addi %base_int, %bytes : tensor<128xi64>
  %out = tt.int_to_ptr %shifted : tensor<128xi64> -> tensor<128x!tt.ptr<f32>>
  // CHECK-NEXT: tt.return %[[RESULT]] : tensor<128x!tt.ptr<f32>>
  tt.return %out : tensor<128x!tt.ptr<f32>>
}
204+
205+
// -----
206+
207+
// CHECK-LABEL: @canonicalize_int_to_ptr_with_constant_offset_f16
// A 32-byte offset on f16 (2 bytes each) is 16 elements.
tt.func @canonicalize_int_to_ptr_with_constant_offset_f16(%base: tensor<1024x!tt.ptr<f16>>) -> tensor<1024x!tt.ptr<f16>> {
  // CHECK: %[[OFFSET:.*]] = arith.constant dense<16> : tensor<1024xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = tt.addptr %{{.*}}, %[[OFFSET]] : tensor<1024x!tt.ptr<f16>>, tensor<1024xi64>
  %bytes = arith.constant dense<32> : tensor<1024xi64>
  %base_int = tt.ptr_to_int %base : tensor<1024x!tt.ptr<f16>> -> tensor<1024xi64>
  %shifted = arith.addi %base_int, %bytes : tensor<1024xi64>
  %out = tt.int_to_ptr %shifted : tensor<1024xi64> -> tensor<1024x!tt.ptr<f16>>
  // CHECK-NEXT: tt.return %[[RESULT]] : tensor<1024x!tt.ptr<f16>>
  tt.return %out : tensor<1024x!tt.ptr<f16>>
}
219+
220+
// -----
221+
222+
// CHECK-LABEL: @no_canonicalize_non_constant_offset
// Negative case: a runtime (non-constant) offset must be left untouched.
tt.func @no_canonicalize_non_constant_offset(%base: tensor<128x!tt.ptr<f32>>, %offset: tensor<128xi64>) -> tensor<128x!tt.ptr<f32>> {
  // CHECK: tt.ptr_to_int
  // CHECK-NEXT: arith.addi
  // CHECK-NEXT: tt.int_to_ptr
  %base_int = tt.ptr_to_int %base : tensor<128x!tt.ptr<f32>> -> tensor<128xi64>
  %shifted = arith.addi %base_int, %offset : tensor<128xi64>
  %out = tt.int_to_ptr %shifted : tensor<128xi64> -> tensor<128x!tt.ptr<f32>>
  tt.return %out : tensor<128x!tt.ptr<f32>>
}
233+
234+
// -----
235+
236+
// CHECK-LABEL: @no_canonicalize_indivisible_offset
// Negative case: 7 bytes is not a multiple of sizeof(f32) = 4, so the
// byte offset cannot be expressed as a whole element offset.
tt.func @no_canonicalize_indivisible_offset(%base: tensor<128x!tt.ptr<f32>>) -> tensor<128x!tt.ptr<f32>> {
  // CHECK: tt.ptr_to_int
  // CHECK-NEXT: arith.addi
  // CHECK-NEXT: tt.int_to_ptr
  %bytes = arith.constant dense<7> : tensor<128xi64>
  %base_int = tt.ptr_to_int %base : tensor<128x!tt.ptr<f32>> -> tensor<128xi64>
  %shifted = arith.addi %base_int, %bytes : tensor<128xi64>
  %out = tt.int_to_ptr %shifted : tensor<128xi64> -> tensor<128x!tt.ptr<f32>>
  tt.return %out : tensor<128x!tt.ptr<f32>>
}

0 commit comments

Comments
 (0)