Skip to content

Commit 7e114c3

Browse files
asraa authored and copybara-github committed
Add domain schedule attribute to tensor_ext.assign_layout and convert_layout
* Utilize this for a domain schedule for the 2D conv NCHW/FCHW.
* Update implementAssignLayout to take the domain schedule and pass it into the loop generation string.
* Add syntax and folder tests.

PiperOrigin-RevId: 906428225
1 parent d855965 commit 7e114c3

9 files changed

Lines changed: 124 additions & 21 deletions

File tree

lib/Dialect/TensorExt/IR/TensorExtOps.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,29 @@ LogicalResult ConvertLayoutOp::verify() {
198198
}
199199

200200
LogicalResult AssignLayoutOp::verify() {
201-
return verifyLayoutMatchesType(getLayout(), getValue().getType(), *this);
201+
LogicalResult layoutVerification =
202+
verifyLayoutMatchesType(getLayout(), getValue().getType(), *this);
203+
if (failed(layoutVerification)) {
204+
return layoutVerification;
205+
}
206+
207+
if (!getDomainSchedule().empty()) {
208+
auto layout = dyn_cast<LayoutAttr>(getLayout());
209+
if (!layout) {
210+
return emitOpError()
211+
<< "requires LayoutAttr when domainSchedule is provided";
212+
}
213+
presburger::IntegerRelation rel = layout.getIntegerRelation();
214+
for (int64_t idx : getDomainSchedule()) {
215+
if (idx < 0 || idx >= rel.getNumDomainVars()) {
216+
return emitOpError()
217+
<< "domainSchedule index " << idx << " is out of bounds [0, "
218+
<< rel.getNumDomainVars() << ")";
219+
}
220+
}
221+
}
222+
223+
return success();
202224
}
203225

204226
LogicalResult UnpackOp::verify() {

lib/Dialect/TensorExt/IR/TensorExtOps.td

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,20 @@ def TensorExt_ConvertLayoutOp : TensorExt_Op<"convert_layout", [Pure, AllTypesMa
110110
plaintext masks (ciphertext-plaintext multiplications), and additions.
111111

112112
This op is inserted by layout selection passes.
113+
114+
The optional `domainSchedule` attribute specifies which extra domain
115+
dimensions should be part of the loop nest schedule when lowering. Generated
116+
loop nests are only used when the layout conversion is used for a plaintext
117+
layout assignment. See `tensor_ext.assign_layout` for more details.
113118
}];
114119

115120
let assemblyFormat = "operands attr-dict `:` type($output)";
116-
let arguments = (ins AnyType:$value, LayoutLike:$from_layout, LayoutLike:$to_layout);
121+
let arguments = (ins
122+
AnyType:$value,
123+
LayoutLike:$from_layout,
124+
LayoutLike:$to_layout,
125+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$domainSchedule
126+
);
117127
let results = (outs AnyType:$output);
118128
let hasVerifier = 1;
119129
let hasFolder = 1;
@@ -130,10 +140,20 @@ def TensorExt_AssignLayoutOp : TensorExt_Op<"assign_layout", [Pure, AllTypesMatc
130140
lowered as ciphertext-plaintext ops.
131141

132142
This op is inserted by layout selection passes.
143+
144+
The optional `domainSchedule` attribute specifies which domain dimensions
145+
should be part of the loop nest schedule when lowering. The range dimensions
146+
are always included. Adding extra domain indices to the loop schedule will
147+
help ISL generate a loop nest more efficiently, but comes at the tradeoff of
148+
a slower generated loop nest.
133149
}];
134150

135151
let assemblyFormat = "operands attr-dict `:` type($output)";
136-
let arguments = (ins AnyType:$value, LayoutLike:$layout);
152+
let arguments = (ins
153+
AnyType:$value,
154+
LayoutLike:$layout,
155+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$domainSchedule
156+
);
137157
let results = (outs AnyType:$output);
138158
let hasVerifier = 1;
139159
}

lib/Dialect/TensorExt/Transforms/Patterns.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
#include "lib/Dialect/TensorExt/IR/TensorExtDialect.h"
66
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
77
#include "lib/Utils/AttributeUtils.h"
8-
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
9-
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
10-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
8+
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
9+
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
10+
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
11+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1112

1213
namespace mlir {
1314
namespace heir {
@@ -33,8 +34,15 @@ LogicalResult FoldConvertLayoutIntoAssignLayoutPattern::matchAndRewrite(
3334
continue;
3435
}
3536

36-
auto newOp = rewriter.replaceOpWithNewOp<AssignLayoutOp>(
37-
convertLayoutOp, op.getValue(), convertLayoutOp.getToLayout());
37+
DenseI64ArrayAttr domainSchedule =
38+
!convertLayoutOp.getDomainSchedule().empty()
39+
? convertLayoutOp.getDomainScheduleAttr()
40+
: op.getDomainScheduleAttr();
41+
42+
auto newOp = AssignLayoutOp::create(
43+
rewriter, convertLayoutOp.getLoc(), op.getValue(),
44+
convertLayoutOp.getToLayout(), domainSchedule);
45+
rewriter.replaceOp(convertLayoutOp, newOp.getResult());
3846
// Ensure the newOp has its layout attribute properly set
3947
setAttributeAssociatedWith(newOp.getResult(),
4048
TensorExtDialect::kLayoutAttrName,

lib/Transforms/ConvertToCiphertextSemantics/AssignLayout.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "lib/Utils/Layout/Utils.h"
1111
#include "llvm/include/llvm/ADT/DynamicAPInt.h" // from @llvm-project
1212
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
13+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
1314
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
1415
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
1516
#include "mlir/include/mlir/Analysis/FlatLinearValueConstraints.h" // from @llvm-project
@@ -48,7 +49,8 @@ static std::string printRelation(const IntegerRelation& rel) {
4849
static FailureOr<Value> implementAssignLayoutNew(
4950
Value input, LayoutAttr layout, int64_t ciphertextSize,
5051
ImplicitLocOpBuilder& builder,
51-
const std::function<void(Operation*)>& createdOpCallback) {
52+
const std::function<void(Operation*)>& createdOpCallback,
53+
ArrayRef<int64_t> domainSchedule = {}) {
5254
IntegerRelation rel = layout.getIntegerRelation();
5355

5456
RankedTensorType dataSemanticType =
@@ -106,7 +108,10 @@ static FailureOr<Value> implementAssignLayoutNew(
106108
TypedValue<RankedTensorType> ciphertextTensor =
107109
cast<TypedValue<RankedTensorType>>(zeroOp.getResult());
108110
MLIRLoopNestGenerator generator(builder, createdOpCallback);
109-
auto loopNestCstr = generateLoopNestAsCStr(rel);
111+
112+
SmallVector<int> domainIndices = llvm::to_vector(llvm::map_range(
113+
domainSchedule, [](int64_t idx) { return static_cast<int>(idx); }));
114+
auto loopNestCstr = generateLoopNestAsCStr(rel, domainIndices);
110115
if (failed(loopNestCstr)) {
111116
return builder.emitError() << "Failed to generate loop nest for relation "
112117
<< printRelation(rel);
@@ -141,7 +146,8 @@ static FailureOr<Value> implementAssignLayoutNew(
141146
auto inserted = tensor::InsertOp::create(builder, loc, extracted,
142147
iterArgs[0], insertIndices);
143148
return scf::ValueVector({inserted});
144-
});
149+
},
150+
domainIndices);
145151
if (failed(loop)) {
146152
return builder.emitError() << "Failed to generate loop nest for relation "
147153
<< printRelation(rel);
@@ -293,11 +299,12 @@ static FailureOr<Value> implementAssignLayoutPermutation(
293299
FailureOr<Value> implementAssignLayout(
294300
Value input, Attribute layout, int64_t ciphertextSize,
295301
ImplicitLocOpBuilder& builder,
296-
const std::function<void(Operation*)>& createdOpCallback) {
302+
const std::function<void(Operation*)>& createdOpCallback,
303+
ArrayRef<int64_t> domainSchedule) {
297304
OpBuilder::InsertionGuard guard(builder);
298305
if (LayoutAttr layoutAttr = dyn_cast<LayoutAttr>(layout)) {
299306
return implementAssignLayoutNew(input, layoutAttr, ciphertextSize, builder,
300-
createdOpCallback);
307+
createdOpCallback, domainSchedule);
301308
} else if (DenseIntElementsAttr elementAttr =
302309
dyn_cast<DenseIntElementsAttr>(layout)) {
303310
return implementAssignLayoutPermutation(input, elementAttr, ciphertextSize,

lib/Transforms/ConvertToCiphertextSemantics/AssignLayout.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace heir {
2323
FailureOr<Value> implementAssignLayout(
2424
Value input, Attribute layout, int64_t ciphertextSize,
2525
ImplicitLocOpBuilder& builder,
26-
const std::function<void(Operation*)>& createdOpCallback);
26+
const std::function<void(Operation*)>& createdOpCallback,
27+
ArrayRef<int64_t> domainSchedule = {});
2728

2829
// Lower tensor_ext.unpack. Returns the final value produced by the unpacking
2930
// implementation. Applies createdOpCallback to each created operation.

lib/Transforms/LayoutOptimization/Patterns.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ LogicalResult tryFoldLayoutConversionIntoPrevious(
3131
auto newFrom = priorConversion.getFromLayoutAttr();
3232
auto newTo = op.getToLayoutAttr();
3333
auto newConversion = ConvertLayoutOp::create(
34-
rewriter, op.getLoc(), priorConversion.getValue(), newFrom, newTo);
34+
rewriter, op.getLoc(), priorConversion.getValue(), newFrom, newTo,
35+
op.getDomainScheduleAttr());
3536
newConversion->setAttr(kLayoutAttrName, newTo);
3637

3738
rewriter.replaceAllUsesWith(op, newConversion);
@@ -45,8 +46,12 @@ LogicalResult tryFoldLayoutConversionIntoPrevious(
4546
if (auto priorAssignment = op.getValue().getDefiningOp<AssignLayoutOp>()) {
4647
// merge the conversion into the assignment return success();
4748
auto newLayout = op.getToLayoutAttr();
48-
auto newAssign = AssignLayoutOp::create(
49-
rewriter, op.getLoc(), priorAssignment.getValue(), newLayout);
49+
auto domainSchedule = !op.getDomainSchedule().empty()
50+
? op.getDomainScheduleAttr()
51+
: priorAssignment.getDomainScheduleAttr();
52+
auto newAssign = AssignLayoutOp::create(rewriter, op.getLoc(),
53+
priorAssignment.getValue(),
54+
newLayout, domainSchedule);
5055
newAssign->setAttr(kLayoutAttrName, newLayout);
5156

5257
rewriter.replaceAllUsesWith(op, newAssign);

lib/Transforms/LayoutPropagation/LayoutPropagation.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,13 @@ void debugAssignLayout(Value value, LayoutAttr layout) {
111111

112112
std::pair<Value, LayoutAttr> convertToLayout(
113113
MLIRContext* ctx, mlir::IRRewriter& builder, Operation* op, Value value,
114-
LayoutAttr oldLayout, const IntegerRelation& newRelation) {
114+
LayoutAttr oldLayout, const IntegerRelation& newRelation,
115+
ArrayRef<int64_t> domainSchedule = {}) {
115116
builder.setInsertionPoint(op);
116117
LayoutAttr layoutAttr = LayoutAttr::getFromIntegerRelation(ctx, newRelation);
117118
ConvertLayoutOp convertLayoutOp = ConvertLayoutOp::create(
118-
builder, op->getLoc(), value, oldLayout, layoutAttr);
119+
builder, op->getLoc(), value, oldLayout, layoutAttr,
120+
builder.getDenseI64ArrayAttr(domainSchedule));
119121
convertLayoutOp->setAttr(tensor_ext::TensorExtDialect::kLayoutAttrName,
120122
layoutAttr);
121123
Value toReplace = convertLayoutOp.getResult();
@@ -722,9 +724,13 @@ LogicalResult LayoutPropagation::visitOperation(Conv2DNchwFchwOp op) {
722724
"inserting layout conversion.\n");
723725

724726
// Insert a layout conversion op to make the matrix layout expanded and
725-
// squat diagonal
727+
// squat diagonal. The added domain schedule ensures ISL can efficiently
728+
// generate a loop nest implementing the layout. However, the choice of
729+
// which domain indices to include is arbitrary (so long as ISL remains
730+
// fast).
726731
auto [toReplace, newFilterLayoutAttr] = convertToLayout(
727-
ctx, builder, op, filter, filterLayout, convRelation.value());
732+
ctx, builder, op, filter, filterLayout, convRelation.value(),
733+
/*domainSchedule=*/{0, 1});
728734
debugAssignLayout(toReplace, newFilterLayoutAttr);
729735
assignedLayouts.insert({toReplace, newFilterLayoutAttr});
730736
}

tests/Dialect/TensorExt/IR/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,13 @@ func.func @test_halevi_shoup_reduction(%0: tensor<16xi32>, %1: tensor<16x16xi32>
3232
%2 = tensor_ext.rotate_and_reduce %0, %1 {period = 1 : index, steps = 16 : index} : (tensor<16xi32>, tensor<16x16xi32>) -> tensor<16xi32>
3333
return %2 : tensor<16xi32>
3434
}
35+
36+
func.func @test_convert_layout_with_schedule(%0: tensor<16x16xi32>) -> tensor<16x16xi32> {
37+
%1 = tensor_ext.convert_layout %0 {from_layout = #layout1, to_layout = #layout2, domainSchedule = array<i64: 0, 1>} : tensor<16x16xi32>
38+
return %1 : tensor<16x16xi32>
39+
}
40+
41+
func.func @test_assign_layout_with_schedule(%0: tensor<16x16xi32>) -> tensor<16x16xi32> {
42+
%1 = tensor_ext.assign_layout %0 {layout = #layout1, domainSchedule = array<i64: 1, 0>} : tensor<16x16xi32>
43+
return %1 : tensor<16x16xi32>
44+
}

tests/Transforms/fold_convert_layout_to_assign_layout/fold.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,27 @@ func.func @fold_multiple(%arg0 : tensor<16x16xi16>) -> (tensor<16x16xi16>, tenso
2424
%2 = tensor_ext.convert_layout %0 {from_layout = #row_major_matrix, to_layout = #col_major_matrix2} : tensor<16x16xi16>
2525
return %1, %2 : tensor<16x16xi16>, tensor<16x16xi16>
2626
}
27+
28+
// CHECK: @assign_layout_with_schedule
29+
func.func @assign_layout_with_schedule(%arg0 : tensor<16x16xi16>) -> tensor<16x16xi16> {
30+
// CHECK: tensor_ext.assign_layout {{.*}}domainSchedule = array<i64: 1, 0>
31+
%0 = tensor_ext.assign_layout %arg0 {layout = #row_major_matrix, domainSchedule = array<i64: 1, 0>} : tensor<16x16xi16>
32+
%1 = tensor_ext.convert_layout %0 {from_layout = #row_major_matrix, to_layout = #col_major_matrix} : tensor<16x16xi16>
33+
return %1 : tensor<16x16xi16>
34+
}
35+
36+
// CHECK: @convert_layout_with_schedule
37+
func.func @convert_layout_with_schedule(%arg0 : tensor<16x16xi16>) -> tensor<16x16xi16> {
38+
// CHECK: tensor_ext.assign_layout {{.*}}domainSchedule = array<i64: 0, 1>
39+
%0 = tensor_ext.assign_layout %arg0 {layout = #row_major_matrix} : tensor<16x16xi16>
40+
%1 = tensor_ext.convert_layout %0 {from_layout = #row_major_matrix, to_layout = #col_major_matrix, domainSchedule = array<i64: 0, 1>} : tensor<16x16xi16>
41+
return %1 : tensor<16x16xi16>
42+
}
43+
44+
// CHECK: @both_with_schedule
45+
func.func @both_with_schedule(%arg0 : tensor<16x16xi16>) -> tensor<16x16xi16> {
46+
// CHECK: tensor_ext.assign_layout {{.*}}domainSchedule = array<i64: 0, 1>
47+
%0 = tensor_ext.assign_layout %arg0 {layout = #row_major_matrix, domainSchedule = array<i64: 1, 0>} : tensor<16x16xi16>
48+
%1 = tensor_ext.convert_layout %0 {from_layout = #row_major_matrix, to_layout = #col_major_matrix, domainSchedule = array<i64: 0, 1>} : tensor<16x16xi16>
49+
return %1 : tensor<16x16xi16>
50+
}

0 commit comments

Comments (0)