@@ -1221,6 +1221,86 @@ struct AffineParallelFission : public OpRewritePattern<AffineParallelOp> {
12211221 }
12221222};
12231223
1224+ struct AffineParallelToFor : public OpRewritePattern <AffineParallelOp> {
1225+ using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
1226+
1227+ LogicalResult matchAndRewrite (AffineParallelOp parallelOp,
1228+ PatternRewriter &rewriter) const override {
1229+
1230+ // Skip if there are reductions - they need special handling
1231+ if (!parallelOp.getReductions ().empty ()) {
1232+ return failure ();
1233+ }
1234+
1235+ // Skip if there are result types - parallel loops with returns need special handling
1236+ if (!parallelOp.getResultTypes ().empty ()) {
1237+ return failure ();
1238+ }
1239+
1240+ Location loc = parallelOp.getLoc ();
1241+
1242+ // Get the bounds and steps
1243+ auto lowerBounds = parallelOp.getLowerBoundsMap ();
1244+ auto upperBounds = parallelOp.getUpperBoundsMap ();
1245+ auto steps = parallelOp.getSteps ();
1246+ auto lowerOperands = parallelOp.getLowerBoundsOperands ();
1247+ auto upperOperands = parallelOp.getUpperBoundsOperands ();
1248+ auto ivs = parallelOp.getIVs ();
1249+
1250+ // Start building nested for loops from outermost to innermost
1251+ OpBuilder::InsertionGuard guard (rewriter);
1252+ rewriter.setInsertionPoint (parallelOp);
1253+
1254+ // Create nested affine.for loops
1255+ SmallVector<AffineForOp> forOps;
1256+ SmallVector<Value> newIVs;
1257+
1258+ for (unsigned i = 0 ; i < ivs.size (); ++i) {
1259+ // Extract bounds for this dimension
1260+ auto lbMap = lowerBounds.getSliceMap (i, 1 );
1261+ auto ubMap = upperBounds.getSliceMap (i, 1 );
1262+ int64_t step = steps[i];
1263+
1264+ auto forOp = rewriter.create <AffineForOp>(
1265+ loc,
1266+ lowerOperands, lbMap,
1267+ upperOperands, ubMap,
1268+ step
1269+ );
1270+
1271+ forOps.push_back (forOp);
1272+ newIVs.push_back (forOp.getInductionVar ());
1273+
1274+ // Set insertion point for next loop or body
1275+ rewriter.setInsertionPointToStart (forOp.getBody ());
1276+ }
1277+
1278+ // Move the body content from parallel to innermost for loop
1279+ Block *parallelBody = parallelOp.getBody ();
1280+ Block *targetBody = forOps.empty () ? nullptr : forOps.back ().getBody ();
1281+
1282+ if (!targetBody) {
1283+ return failure ();
1284+ }
1285+
1286+ // Create mapping for induction variables
1287+ IRMapping mapping;
1288+ for (auto [parallelIV, newIV] : llvm::zip (ivs, newIVs)) {
1289+ mapping.map (parallelIV, newIV);
1290+ }
1291+
1292+ // Clone operations from parallel body to for body (excluding terminator)
1293+ for (auto &op : parallelBody->without_terminator ()) {
1294+ rewriter.clone (op, mapping);
1295+ }
1296+
1297+ // Remove the original parallel loop
1298+ rewriter.eraseOp (parallelOp);
1299+
1300+ return success ();
1301+ }
1302+ };
1303+
12241304// namespace {
12251305// struct RaiseAffineToLinalg
12261306// : public AffineRaiseToLinalgBase<RaiseAffineToLinalg> {
@@ -1255,18 +1335,37 @@ struct RaiseAffineToLinalg
12551335} // namespace
12561336
12571337void RaiseAffineToLinalg::runOnOperation () {
1258- RewritePatternSet patterns (&getContext ());
1259- // TODO add the existing canonicalization patterns
1260- // + subview of an affine apply -> subview
1338+ GreedyRewriteConfig config;
12611339
1262- // Add the fission pattern first (preprocessing step)
1263- patterns.insert <AffineParallelFission>(&getContext ());
1340+ // Step 1: Apply fission pattern first
1341+ {
1342+ RewritePatternSet fissionPatterns (&getContext ());
1343+ fissionPatterns.insert <AffineParallelFission>(&getContext ());
1344+ if (failed (applyPatternsAndFoldGreedily (getOperation (), std::move (fissionPatterns), config))) {
1345+ signalPassFailure ();
1346+ return ;
1347+ }
1348+ }
12641349
1265- // Then add the main raising pattern
1266- patterns.insert <AffineForOpRaising>(&getContext ());
1267- GreedyRewriteConfig config;
1268- (void )applyPatternsAndFoldGreedily (getOperation (), std::move (patterns),
1269- config);
1350+ // Step 2: Apply parallel-to-for conversion
1351+ {
1352+ RewritePatternSet parallelToForPatterns (&getContext ());
1353+ parallelToForPatterns.insert <AffineParallelToFor>(&getContext ());
1354+ if (failed (applyPatternsAndFoldGreedily (getOperation (), std::move (parallelToForPatterns), config))) {
1355+ signalPassFailure ();
1356+ return ;
1357+ }
1358+ }
1359+
1360+ // Step 3: Apply raising pattern
1361+ {
1362+ RewritePatternSet raisingPatterns (&getContext ());
1363+ raisingPatterns.insert <AffineForOpRaising>(&getContext ());
1364+ if (failed (applyPatternsAndFoldGreedily (getOperation (), std::move (raisingPatterns), config))) {
1365+ signalPassFailure ();
1366+ return ;
1367+ }
1368+ }
12701369}
12711370
12721371namespace mlir {
0 commit comments