Skip to content

Commit c7fc1e3

Browse files
[Analysis] divisibility handling for dividing by a power-of-two constant (#6657)
Given this ttir: ``` tt.func @div(%arg0: i32 {tt.divisibility = 16 : i32}) { %1 = arith.constant 2 : i32 %2 = arith.divsi %arg0, %1 : i32 } ``` Triton should be able to infer that %2 has divisibility = 8. Use case: a kernel in which I load with a `mask = offsets < numel // 2`.
1 parent 415c3ac commit c7fc1e3

2 files changed

Lines changed: 25 additions & 1 deletion

File tree

lib/Analysis/AxisInfo.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,13 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
467467
if (rhs.getConstantValue().has_value() &&
468468
rhs.getConstantValue().value() == 1)
469469
return lhs.getDivisibility(dim);
470+
// Case 3: lhs has contiguity of 1 in this dimension and rhs is a power of 2
471+
if (rhs.getConstantValue().has_value() &&
472+
llvm::isPowerOf2_64(std::abs(rhs.getConstantValue().value())) &&
473+
lhs.getContiguity(dim) == 1) {
474+
int64_t absRhs = std::abs(rhs.getConstantValue().value());
475+
return std::max<int64_t>(1, lhs.getDivisibility(dim) / absRhs);
476+
}
470477
// otherwise: return 1
471478
return 1;
472479
}

test/Analysis/test-alignment.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tt.func @mul(%arg0: i64 {tt.divisibility = 16 : i32}) {
129129

130130
// -----
131131

132-
tt.func @div() {
132+
tt.func @div(%arg0: i32 {tt.divisibility = 16 : i32}) {
133133
// expected-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>}}
134134
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
135135
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}}
@@ -154,6 +154,23 @@ tt.func @div() {
154154
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
155155
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>}}
156156
%11 = arith.divsi %10, %4 : tensor<128xi32>
157+
// expected-remark @below {{contiguity = [1], divisibility = [2], constancy = [1], constant_value = 2}}
158+
%12 = arith.constant 2 : i32
159+
// dividing a scalar by a power of two should give predictable divisibility
160+
// expected-remark @below {{contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>}}
161+
%13 = arith.divsi %arg0, %12 : i32
162+
// expected-remark @below {{contiguity = [1], divisibility = [32], constancy = [1], constant_value = 32}}
163+
%14 = arith.constant 32 : i32
164+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
165+
%15 = arith.divsi %arg0, %14 : i32
166+
// expected-remark @below {{contiguity = [1], divisibility = [2], constancy = [1], constant_value = 6}}
167+
%16 = arith.constant 6 : i32
168+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
169+
%17 = arith.divsi %arg0, %16 : i32
170+
// expected-remark @below {{contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2}}
171+
%18 = arith.constant dense<2> : tensor<128xi32>
172+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [2], constant_value = <none>}}
173+
%19 = arith.divsi %0, %18 : tensor<128xi32>
157174
tt.return
158175
}
159176

0 commit comments

Comments
 (0)