|
10 | 10 | #include "lib/Utils/Layout/Utils.h" |
11 | 11 | #include "llvm/include/llvm/ADT/DynamicAPInt.h" // from @llvm-project |
12 | 12 | #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project |
| 13 | +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project |
13 | 14 | #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project |
14 | 15 | #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project |
15 | 16 | #include "mlir/include/mlir/Analysis/FlatLinearValueConstraints.h" // from @llvm-project |
@@ -48,7 +49,8 @@ static std::string printRelation(const IntegerRelation& rel) { |
48 | 49 | static FailureOr<Value> implementAssignLayoutNew( |
49 | 50 | Value input, LayoutAttr layout, int64_t ciphertextSize, |
50 | 51 | ImplicitLocOpBuilder& builder, |
51 | | - const std::function<void(Operation*)>& createdOpCallback) { |
| 52 | + const std::function<void(Operation*)>& createdOpCallback, |
| 53 | + ArrayRef<int64_t> domainSchedule = {}) { |
52 | 54 | IntegerRelation rel = layout.getIntegerRelation(); |
53 | 55 |
|
54 | 56 | RankedTensorType dataSemanticType = |
@@ -106,7 +108,10 @@ static FailureOr<Value> implementAssignLayoutNew( |
106 | 108 | TypedValue<RankedTensorType> ciphertextTensor = |
107 | 109 | cast<TypedValue<RankedTensorType>>(zeroOp.getResult()); |
108 | 110 | MLIRLoopNestGenerator generator(builder, createdOpCallback); |
109 | | - auto loopNestCstr = generateLoopNestAsCStr(rel); |
| 111 | + |
| 112 | + SmallVector<int> domainIndices = llvm::to_vector(llvm::map_range( |
| 113 | + domainSchedule, [](int64_t idx) { return static_cast<int>(idx); })); |
| 114 | + auto loopNestCstr = generateLoopNestAsCStr(rel, domainIndices); |
110 | 115 | if (failed(loopNestCstr)) { |
111 | 116 | return builder.emitError() << "Failed to generate loop nest for relation " |
112 | 117 | << printRelation(rel); |
@@ -141,7 +146,8 @@ static FailureOr<Value> implementAssignLayoutNew( |
141 | 146 | auto inserted = tensor::InsertOp::create(builder, loc, extracted, |
142 | 147 | iterArgs[0], insertIndices); |
143 | 148 | return scf::ValueVector({inserted}); |
144 | | - }); |
| 149 | + }, |
| 150 | + domainIndices); |
145 | 151 | if (failed(loop)) { |
146 | 152 | return builder.emitError() << "Failed to generate loop nest for relation " |
147 | 153 | << printRelation(rel); |
@@ -293,11 +299,12 @@ static FailureOr<Value> implementAssignLayoutPermutation( |
293 | 299 | FailureOr<Value> implementAssignLayout( |
294 | 300 | Value input, Attribute layout, int64_t ciphertextSize, |
295 | 301 | ImplicitLocOpBuilder& builder, |
296 | | - const std::function<void(Operation*)>& createdOpCallback) { |
| 302 | + const std::function<void(Operation*)>& createdOpCallback, |
| 303 | + ArrayRef<int64_t> domainSchedule) { |
297 | 304 | OpBuilder::InsertionGuard guard(builder); |
298 | 305 | if (LayoutAttr layoutAttr = dyn_cast<LayoutAttr>(layout)) { |
299 | 306 | return implementAssignLayoutNew(input, layoutAttr, ciphertextSize, builder, |
300 | | - createdOpCallback); |
| 307 | + createdOpCallback, domainSchedule); |
301 | 308 | } else if (DenseIntElementsAttr elementAttr = |
302 | 309 | dyn_cast<DenseIntElementsAttr>(layout)) { |
303 | 310 | return implementAssignLayoutPermutation(input, elementAttr, ciphertextSize, |
|
0 commit comments