Skip to content

Commit 53c5d14

Browse files
committed
Added pattern for parallel to seq for loops
1 parent cb34836 commit 53c5d14

File tree

1 file changed

+109
-10
lines changed

1 file changed

+109
-10
lines changed

lib/polygeist/Passes/RaiseToLinalg.cpp

Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,86 @@ struct AffineParallelFission : public OpRewritePattern<AffineParallelOp> {
12211221
}
12221222
};
12231223

1224+
/// Rewrites an `affine.parallel` op that yields no values into a nest of
/// sequential `affine.for` loops, one loop per induction variable, with the
/// first parallel dimension becoming the outermost loop.
///
/// Parallel ops that carry reductions or produce results are left untouched:
/// they would need their values threaded through the new loops, which this
/// pattern does not do.
struct AffineParallelToFor : public OpRewritePattern<AffineParallelOp> {
  using OpRewritePattern<AffineParallelOp>::OpRewritePattern;

  /// Returns success() after replacing `parallelOp` with the loop nest, or
  /// failure() (leaving the IR unmodified) when the op has reductions,
  /// results, or no induction variables.
  LogicalResult matchAndRewrite(AffineParallelOp parallelOp,
                                PatternRewriter &rewriter) const override {
    // Reductions and results need special handling that is not implemented.
    if (!parallelOp.getReductions().empty() ||
        !parallelOp.getResultTypes().empty())
      return failure();

    // Bail out before touching any rewriter state: with no induction
    // variables there is no loop to host the body.
    auto ivs = parallelOp.getIVs();
    if (ivs.empty())
      return failure();

    Location loc = parallelOp.getLoc();
    auto steps = parallelOp.getSteps();
    auto lowerOperands = parallelOp.getLowerBoundsOperands();
    auto upperOperands = parallelOp.getUpperBoundsOperands();

    // Restore the caller's insertion point when we leave this function.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(parallelOp);

    // Build the loops outermost-first, mapping each parallel induction
    // variable to the induction variable of the loop that replaces it.
    IRMapping mapping;
    for (unsigned dim = 0, numDims = ivs.size(); dim < numDims; ++dim) {
      // Use the per-dimension accessors rather than slicing the flattened
      // bound maps at (dim, 1): a single dimension may own several bound
      // expressions (its bound "group"), so result `dim` of the flattened
      // map is not necessarily the bound of dimension `dim`.
      AffineMap lbMap = parallelOp.getLowerBoundMap(dim);
      AffineMap ubMap = parallelOp.getUpperBoundMap(dim);
      int64_t step = steps[dim];

      auto forOp = rewriter.create<AffineForOp>(loc, lowerOperands, lbMap,
                                                upperOperands, ubMap, step);
      mapping.map(ivs[dim], forOp.getInductionVar());

      // Whatever is created next (the next loop, or the cloned body) goes
      // at the start of this loop's body.
      rewriter.setInsertionPointToStart(forOp.getBody());
    }

    // Clone the parallel body, minus its terminator, into the innermost
    // loop; the insertion point was left there by the last loop iteration.
    for (Operation &op : parallelOp.getBody()->without_terminator())
      rewriter.clone(op, mapping);

    // The original parallel loop is now fully replaced.
    rewriter.eraseOp(parallelOp);
    return success();
  }
};
1303+
12241304
// namespace {
12251305
// struct RaiseAffineToLinalg
12261306
// : public AffineRaiseToLinalgBase<RaiseAffineToLinalg> {
@@ -1255,18 +1335,37 @@ struct RaiseAffineToLinalg
12551335
} // namespace
12561336

12571337
void RaiseAffineToLinalg::runOnOperation() {
1258-
RewritePatternSet patterns(&getContext());
1259-
// TODO add the existing canonicalization patterns
1260-
// + subview of an affine apply -> subview
1338+
GreedyRewriteConfig config;
12611339

1262-
// Add the fission pattern first (preprocessing step)
1263-
patterns.insert<AffineParallelFission>(&getContext());
1340+
// Step 1: Apply fission pattern first
1341+
{
1342+
RewritePatternSet fissionPatterns(&getContext());
1343+
fissionPatterns.insert<AffineParallelFission>(&getContext());
1344+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(fissionPatterns), config))) {
1345+
signalPassFailure();
1346+
return;
1347+
}
1348+
}
12641349

1265-
// Then add the main raising pattern
1266-
patterns.insert<AffineForOpRaising>(&getContext());
1267-
GreedyRewriteConfig config;
1268-
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
1269-
config);
1350+
// Step 2: Apply parallel-to-for conversion
1351+
{
1352+
RewritePatternSet parallelToForPatterns(&getContext());
1353+
parallelToForPatterns.insert<AffineParallelToFor>(&getContext());
1354+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(parallelToForPatterns), config))) {
1355+
signalPassFailure();
1356+
return;
1357+
}
1358+
}
1359+
1360+
// Step 3: Apply raising pattern
1361+
{
1362+
RewritePatternSet raisingPatterns(&getContext());
1363+
raisingPatterns.insert<AffineForOpRaising>(&getContext());
1364+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(raisingPatterns), config))) {
1365+
signalPassFailure();
1366+
return;
1367+
}
1368+
}
12701369
}
12711370

12721371
namespace mlir {

0 commit comments

Comments
 (0)