Skip to content

Commit 24c109e

Browse files
committed
Fix TLE DSL region inline extract op parent constraints
1 parent 53fcdba commit 24c109e

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

third_party/tle/dialect/include/IR/TleOps.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,50 +112,50 @@ def Tle_YieldOp : Tle_Op<"yield", [Pure, Terminator, ReturnLike,
112112
}
113113

114114
def Tle_ExtractAllocatedPtrOp
115-
: Tle_Op<"extract_allocated_ptr", [Pure, HasParent<"DSLRegionOp">]> {
115+
: Tle_Op<"extract_allocated_ptr", [Pure]> {
116116
let arguments = (ins Tle_ArgType:$input);
117117
let results = (outs LLVMPointerType:$ptr);
118118
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($ptr)";
119119
}
120120

121121
def Tle_ExtractAlignedPtrOp
122-
: Tle_Op<"extract_aligned_ptr", [Pure, HasParent<"DSLRegionOp">]> {
122+
: Tle_Op<"extract_aligned_ptr", [Pure]> {
123123
let arguments = (ins Tle_ArgType:$input);
124124
let results = (outs LLVMPointerType:$ptr);
125125
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($ptr)";
126126
}
127127

128128
def Tle_ExtractOffsetOp
129-
: Tle_Op<"extract_offset", [Pure, HasParent<"DSLRegionOp">]> {
129+
: Tle_Op<"extract_offset", [Pure]> {
130130
let arguments = (ins Tle_TensorType:$input);
131131
let results = (outs I64:$ptr);
132132
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($ptr)";
133133
}
134134

135135
def Tle_ExtractSizesOp
136-
: Tle_Op<"extract_sizes", [Pure, HasParent<"DSLRegionOp">]> {
136+
: Tle_Op<"extract_sizes", [Pure]> {
137137
let arguments = (ins Tle_TensorType:$input);
138138
let results = (outs Variadic<I64>:$sizes);
139139
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($sizes)";
140140
let builders = [OpBuilder<(ins "size_t":$num, "Value":$input)>];
141141
}
142142

143143
def Tle_ExtractStridesOp
144-
: Tle_Op<"extract_strides", [Pure, HasParent<"DSLRegionOp">]> {
144+
: Tle_Op<"extract_strides", [Pure]> {
145145
let arguments = (ins Tle_TensorType:$input);
146146
let results = (outs Variadic<I64>:$strides);
147147
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($strides)";
148148
let builders = [OpBuilder<(ins "size_t":$num, "Value":$input)>];
149149
}
150150

151-
def Tle_ExtractPtrOp : Tle_Op<"extract_ptr", [Pure, HasParent<"DSLRegionOp">]> {
151+
def Tle_ExtractPtrOp : Tle_Op<"extract_ptr", [Pure]> {
152152
let arguments = (ins TT_Ptr:$input);
153153
let results = (outs LLVMPointerType:$ptr);
154154
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($ptr)";
155155
}
156156

157157
def Tle_PackOp
158-
: Tle_Op<"pack", [MemDescViewTrait, Pure, HasParent<"DSLRegionOp">]> {
158+
: Tle_Op<"pack", [MemDescViewTrait, Pure]> {
159159
let arguments = (ins LLVMStructType:$input);
160160
let results = (outs Tle_TensorType:$output);
161161
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";

third_party/tle/dialect/lib/Conversion/TleToLLVM/PackOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ PackOpConversion::PackOpConversion(LLVMTypeConverter &typeConverter,
2828
LogicalResult
2929
PackOpConversion::matchAndRewrite(tle::PackOp op, OpAdaptor adaptor,
3030
ConversionPatternRewriter &rewriter) const {
31-
auto regionOp = op->getParentOfType<tle::DSLRegionOp>();
3231
if (ttg::MemDescType memdesc =
3332
dyn_cast<ttg::MemDescType>(op.getOutput().getType())) {
3433
LLVM::LLVMStructType llvmStructType =
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: triton-opt %s --tle-dslregion-inline | FileCheck %s
2+
3+
module {
4+
llvm.func @_sink(!llvm.ptr)
5+
6+
tt.func @k(%arg0: !tt.ptr<i32>) {
7+
%0 = "tle.dsl_region"(%arg0) ({
8+
^bb0(%in: !tt.ptr<i32>):
9+
%p = "tle.extract_ptr"(%in) : (!tt.ptr<i32>) -> !llvm.ptr
10+
"tle.yield"(%p) : (!llvm.ptr) -> ()
11+
}) : (!tt.ptr<i32>) -> (!llvm.ptr)
12+
llvm.call @_sink(%0) : (!llvm.ptr) -> ()
13+
tt.return
14+
}
15+
}
16+
17+
// CHECK-LABEL: tt.func @k(
18+
// CHECK-NOT: tle.dsl_region
19+
// CHECK: %[[P:.*]] = tle.extract_ptr
20+
// CHECK: llvm.call @_sink(%[[P]]) : (!llvm.ptr) -> ()

0 commit comments

Comments
 (0)