Skip to content

Commit 789859e

Browse files
authored
[Codegen] Use safer hoisting in OptimizeTensorInsertExtractSlices (#23280)
Use the `moveLoopInvariantCodeFromGuaranteedLoops` transform instead of the `moveLoopInvariantCode` transform in the OptimizeTensorInsertExtractSlices pass. This transform is safer because it validates that loops will be executed at least once before hoisting loop-invariant code. Hoisting from loops that may not execute is not an optimization, so this is a better version of the transformation. The new, safer transform also hoists from linalg.generic ops, so the `moveLoopInvariantCodeFromGenericOps` function is removed, since it is no longer used. This PR also removes the `_batch_matmul_narrow_n_2_dispatch_4_unpack_i32` test, which was doing nothing but checking that a tensor.empty op gets hoisted from an scf.for loop (which cannot be guaranteed to execute). Hoisting empty tensors is not the job of this pass, and the test is verbose, so it is simply removed. Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
1 parent af093a8 commit 789859e

3 files changed

Lines changed: 33 additions & 62 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -225,23 +225,6 @@ void hoistSubsetWithLoopInvariantTensor(RewriterBase &rewriter,
225225
}
226226
}
227227

228-
void moveLoopInvariantCodeFromGenericOps(Operation *op) {
229-
// linalg.generic operations are also loop-like, but they don't have
230-
// LoopLikeOpInterface implemented for them.
231-
op->walk([&](linalg::GenericOp genericOp) {
232-
moveLoopInvariantCode(
233-
&genericOp.getBodyRegion(),
234-
[&](Value value, Region *) {
235-
return !genericOp->isAncestor(value.getParentRegion()->getParentOp());
236-
},
237-
[&](Operation *op, Region *) {
238-
return !isa<linalg::IndexOp>(op) && isMemoryEffectFree(op) &&
239-
isSpeculatable(op);
240-
},
241-
[&](Operation *op, Region *) { op->moveBefore(genericOp); });
242-
});
243-
}
244-
245228
namespace {
246229
struct CastLikeExtractSliceOpFolder final
247230
: OpRewritePattern<tensor::ExtractSliceOp> {
@@ -382,11 +365,8 @@ void OptimizeTensorInsertExtractSlicesPass::runOnOperation() {
382365
extractSliceOp->moveAfter(latestInsertionPoint);
383366
});
384367

385-
funcOp.walk([&](scf::ForOp forOp) { moveLoopInvariantCode(forOp); });
386-
LDBG() << "after hoisting loop invariant code\n" << funcOp;
387-
388-
moveLoopInvariantCodeFromGenericOps(funcOp);
389-
LDBG() << "after hoisting loop invariant code out of generic ops\n" << funcOp;
368+
moveLoopInvariantCodeFromGuaranteedLoops(funcOp);
369+
LDBG() << "after hoisting loop invariant code\n" << funcOp << "\n";
390370

391371
// TODO: walking in some reverse / inside-out order would be more efficient
392372
// and would capture more cases.

compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -229,42 +229,6 @@ func.func @batch_matmul_with_padding_strategy(%arg0: tensor<1x?x1280xf16>, %arg1
229229

230230
// -----
231231

232-
#pipeline_layout = #hal.pipeline.layout<bindings = [
233-
#hal.pipeline.binding<storage_buffer>,
234-
#hal.pipeline.binding<storage_buffer>
235-
]>
236-
func.func @_batch_matmul_narrow_n_2_dispatch_4_unpack_i32() attributes {translation_info = #iree_codegen.translation_info<pipeline = CPUDataTiling>} {
237-
%c0_i32 = arith.constant 0 : i32
238-
%c2 = arith.constant 2 : index
239-
%c128 = arith.constant 128 : index
240-
%c0 = arith.constant 0 : index
241-
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c128) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x1x1x2x8xi32>>
242-
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x3x2xi32>>
243-
%workgroup_id_x = hal.interface.workgroup.id[0] : index
244-
%workgroup_count_x = hal.interface.workgroup.count[0] : index
245-
scf.for %arg0 = %workgroup_id_x to %c2 step %workgroup_count_x {
246-
%2 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [1, 3, 2], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x3x2xi32>> -> tensor<1x3x2xi32>
247-
%3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [%arg0, 0, 0, 0, 0], sizes = [1, 1, 1, 2, 8], strides = [1, 1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x1x1x2x8xi32>> -> tensor<1x1x1x2x8xi32>
248-
%4 = vector.transfer_read %3[%c0, %c0, %c0, %c0, %c0], %c0_i32 {in_bounds = [true, true]} : tensor<1x1x1x2x8xi32>, vector<2x8xi32>
249-
%5 = vector.transpose %4, [1, 0] : vector<2x8xi32> to vector<8x2xi32>
250-
%6 = tensor.empty() : tensor<3x2xi32>
251-
%7 = vector.transfer_write %5, %6[%c0, %c0] {in_bounds = [false, true]} : vector<8x2xi32>, tensor<3x2xi32>
252-
%inserted_slice = tensor.insert_slice %7 into %2[0, 0, 0] [1, 3, 2] [1, 1, 1] : tensor<3x2xi32> into tensor<1x3x2xi32>
253-
iree_tensor_ext.dispatch.tensor.store %inserted_slice, %1, offsets = [%arg0, 0, 0], sizes = [1, 3, 2], strides = [1, 1, 1] : tensor<1x3x2xi32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x3x2xi32>>
254-
}
255-
return
256-
}
257-
258-
// CHECK-LABEL: func.func @_batch_matmul_narrow_n_2_dispatch_4_unpack_i32
259-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
260-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2xi32>
261-
// CHECK: scf.for
262-
// CHECK: %[[READ:.+]] = vector.transfer_read
263-
// CHECK: %[[TRANS:.+]] = vector.transpose %[[READ]], [1, 0] : vector<2x8xi32> to vector<8x2xi32>
264-
// CHECK: vector.transfer_write %[[TRANS]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [false, true]} : vector<8x2xi32>, tensor<3x2xi32>
265-
266-
// -----
267-
268232
func.func @subset_hoisting_invariant_tensor(%init: tensor<64x64xf32>, %t: tensor<64x64xf32>) -> tensor<64x64xf32> {
269233
%c0 = arith.constant 0 : index
270234
%c1 = arith.constant 1 : index
@@ -373,3 +337,30 @@ func.func @licm_generic(%source: tensor<32x32xf16>, %idx : index) -> tensor<32x3
373337
// CHECK: linalg.generic
374338
// CHECK-NOT: tensor.extract
375339
// CHECK: return
340+
341+
// -----
342+
343+
// Verify that loop invariant ops are not hoisted from regions that may not be
344+
// executed.
345+
func.func @no_hoist_from_possibly_unexecuted_region(%arg0: tensor<4x8xi32>) -> tensor<8x4xi32> {
346+
%c0_i32 = arith.constant 0 : i32
347+
%c0 = arith.constant 0 : index
348+
%c1 = arith.constant 1 : index
349+
%c100 = arith.constant 100 : index
350+
%workgroup_id_x = hal.interface.workgroup.id[0] : index
351+
%0 = tensor.empty() : tensor<8x4xi32>
352+
%1 = scf.for %arg1 = %workgroup_id_x to %c1 step %c100 iter_args(%arg2 = %0) -> tensor<8x4xi32> {
353+
%2 = vector.transfer_read %arg0[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : tensor<4x8xi32>, vector<2x8xi32>
354+
%3 = vector.transpose %2, [1, 0] : vector<2x8xi32> to vector<8x2xi32>
355+
%4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<8x2xi32>, tensor<8x4xi32>
356+
scf.yield %4 : tensor<8x4xi32>
357+
}
358+
return %1 : tensor<8x4xi32>
359+
}
360+
361+
// CHECK-LABEL: func.func @no_hoist_from_possibly_unexecuted_region
362+
// CHECK: scf.for {{.*}} {
363+
// CHECK: vector.transfer_read
364+
// CHECK: vector.transpose
365+
// CHECK: vector.transfer_write
366+
// CHECK: }

compiler/src/iree/compiler/Codegen/LLVMGPU/test/winograd_pipeline_test.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ func.func @winograd_input_transform() {
4343
}
4444
// CHECK-LABEL: func.func @winograd_input_transform
4545
// CHECK-NOT: memref.alloc
46-
// CHECK: vector.transfer_read
47-
// CHECK: vector.transfer_read
4846
// CHECK: scf.for
4947
// CHECK: scf.for
48+
// CHECK: vector.transfer_read
49+
// CHECK: vector.transfer_read
5050
// CHECK: scf.for
5151
// CHECK: vector.transfer_read
5252
// CHECK: vector.contract
@@ -71,10 +71,10 @@ func.func @winograd_output_transform() {
7171
}
7272
// CHECK-LABEL: func.func @winograd_output_transform
7373
// CHECK-NOT: memref.alloc
74-
// CHECK: vector.transfer_read
75-
// CHECK: vector.transfer_read
7674
// CHECK: scf.for
7775
// CHECK: scf.for
76+
// CHECK: vector.transfer_read
77+
// CHECK: vector.transfer_read
7878
// CHECK: scf.for
7979
// CHECK: vector.transfer_read
8080
// CHECK: vector.contract

0 commit comments

Comments
 (0)