
Commit 310c668

Address PR feedback
Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
1 parent a76d47d commit 310c668

2 files changed: 67 additions & 35 deletions


compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp

Lines changed: 20 additions & 35 deletions
@@ -21,10 +21,10 @@
 #define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis"
 
 // The purpose of this analysis is to propagate information about the
-// undistributed vector tile size across the operation graph. The vector tile
-// size is important information for the vectorization of operations.
-// For example, the vector tile size can be used by GenericVectorization to
-// introduce the necessary masking in the presence of padding/masking.
+// vector tile size across the operation graph. The vector tile size is
+// important information for the vectorization of operations. For example, the
+// vector tile size can be used by GenericVectorization to introduce the
+// necessary masking in the presence of padding/masking.
 //
 // The analysis is a bi-directional dataflow analysis building on top of the
 // upstream MLIR dataflow analysis framework. To implement the bi-directional
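The masking example in this comment can be made concrete: a vectorizer needs a mask whenever a dimension is dynamic or does not divide evenly into the tile size. A toy reading of that sentence, with an invented helper that is not part of this patch:

    #include <cassert>
    #include <cstdint>

    // Hypothetical helper (not from this patch): masking is needed when a
    // dimension is dynamic (unknown size, modeled here as -1) or when the
    // static size is not a multiple of the vector tile size.
    static bool needsMask(int64_t dimSize, int64_t tileSize) {
      const int64_t kDynamic = -1;  // stand-in for ShapedType::kDynamic
      return dimSize == kDynamic || dimSize % tileSize != 0;
    }

    int main() {
      assert(!needsMask(64, 8));  // 64 divides evenly into tiles of 8
      assert(needsMask(60, 8));   // padding remainder -> mask the tail
      assert(needsMask(-1, 8));   // dynamic dim -> mask unconditionally
    }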
@@ -43,7 +43,8 @@
 // As the set union can not result in a conflict, no lattice state for top
 // (overdefined) is required in this lattice.
 //
-// The lattice is initialized from `to_layout` operations.
+// The lattice is initialized from anchor operations that provide information
+// about vector tile size (e.g., `to_layout`).
 //
 // Forward propagation and backward propagation work similarly:
 // - For elementwise operations, candidates from the different operands
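The lattice described here is a candidate set whose join is a plain set union over an explicit uninitialized (bottom) state. A minimal, MLIR-free sketch of that shape; the member layout is an assumption, only the join semantics are taken from the comment above:

    #include <cstdint>
    #include <set>
    #include <vector>

    // Assumed shape of a per-value candidate set: each entry is one possible
    // undistributed tile-size vector. The real TileSizeCandidates in this
    // patch may differ; this only illustrates the lattice semantics.
    struct TileSizeCandidatesSketch {
      bool initialized = false;              // bottom: identity for join
      std::set<std::vector<int64_t>> sizes;  // candidate tile-size vectors

      // Join is set union. A union of candidate sets can only grow, never
      // conflict, so no "top"/overdefined state is needed.
      bool join(const TileSizeCandidatesSketch &rhs) {
        bool changed = !initialized && rhs.initialized;
        initialized |= rhs.initialized;
        for (const auto &s : rhs.sizes)
          changed |= sizes.insert(s).second;
        return changed;  // mirrors ChangeResult in the MLIR framework
      }
    };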
@@ -204,19 +205,14 @@ static bool isDuplicatable(Value val) {
   if (defOp->hasTrait<OpTrait::ConstantLike>()) {
     return true;
   }
-  // Catches linalg.fill that has been lowered/fused into linalg.generic form
-  // (scalar input broadcast into tensor.empty output).
-  if (auto genericOp = dyn_cast<linalg::GenericOp>(defOp)) {
-    if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
-        !isa<ShapedType>(genericOp.getDpsInputs()[0].getType())) {
-      Value init = genericOp.getDpsInits()[0];
-      if (init.getDefiningOp<tensor::EmptyOp>()) {
-        return true;
-      }
-    }
-  }
-  if (auto fillOp = dyn_cast<linalg::FillOp>(defOp)) {
-    if (fillOp.getOutputs()[0].getDefiningOp<tensor::EmptyOp>()) {
+  // A linalg op that doesn't read any tensor data (e.g., linalg.fill or a
+  // fill-like linalg.generic broadcasting a scalar) is a generator and
+  // duplicatable.
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(defOp)) {
+    if (llvm::none_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
+          return isa<ShapedType>(operand.get().getType()) &&
+                 linalgOp.payloadUsesValueFromOperand(&operand);
+        })) {
       return true;
     }
   }
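The rewritten check folds both former special cases (linalg.fill, and a fill-like linalg.generic broadcasting a scalar) into one predicate: an op is duplicatable when none of its shaped operands is actually read by the payload. A toy, MLIR-free illustration of that logic; the Operand type and its fields are invented for the sketch:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Invented stand-in for an OpOperand: whether the operand carries tensor
    // data, and whether the op's payload reads its value.
    struct Operand {
      bool isShaped;
      bool payloadReadsValue;
    };

    // Mirrors the llvm::none_of check above: duplicatable iff the op reads
    // no tensor data, regardless of how the "fill" is spelled.
    static bool readsNoTensorData(const std::vector<Operand> &operands) {
      return std::none_of(operands.begin(), operands.end(),
                          [](const Operand &o) {
                            return o.isShaped && o.payloadReadsValue;
                          });
    }

    int main() {
      // linalg.fill: scalar input (read), tensor init (not read by payload).
      assert(readsNoTensorData({{false, true}, {true, false}}));
      // Elementwise op: tensor input read by the payload -> not duplicatable.
      assert(!readsNoTensorData({{true, true}, {true, false}}));
    }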
@@ -258,20 +254,6 @@ class TileSizeForwardAnalysis
 public:
   using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
 
-  LogicalResult initialize(Operation *top) override {
-    // Seed to_layout anchors before the regular initialization. This ensures
-    // seeds are set even for to_layout ops in regions that DeadCodeAnalysis
-    // hasn't yet marked as live during init.
-    top->walk([&](ToLayoutOp toLayout) {
-      LDBG() << "Anchor: " << toLayout;
-      auto candidates = TileSizeCandidates::fromSizes(
-          toLayout.getLayout().getUndistributedShape());
-      auto *lattice = getLatticeElement(toLayout.getResult());
-      propagateIfChanged(lattice, lattice->join(candidates));
-    });
-    return SparseForwardDataFlowAnalysis::initialize(top);
-  }
-
   void setToEntryState(TileSizeLattice *lattice) override {
     // Entry state is uninitialized (identity for join).
     propagateIfChanged(lattice, lattice->join(TileSizeCandidates()));
@@ -280,9 +262,12 @@ class TileSizeForwardAnalysis
   LogicalResult visitOperation(Operation *op,
                                ArrayRef<const TileSizeLattice *> operands,
                                ArrayRef<TileSizeLattice *> results) override {
-    // to_layout: don't propagate operand forward (anchor boundary).
-    // Seeding is done in initialize().
-    if (isa<ToLayoutOp>(op)) {
+    // to_layout: seed from layout, don't propagate operand forward.
+    if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
+      LDBG() << "Anchor: " << toLayout;
+      auto candidates = TileSizeCandidates::fromSizes(
+          toLayout.getLayout().getUndistributedShape());
+      propagateIfChanged(results[0], results[0]->join(candidates));
       return success();
     }
 
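With this hunk, seeding happens each time the solver visits a to_layout rather than once up front in initialize(). This is presumably safe because the join is idempotent: re-joining the same anchor candidates changes nothing, so the fixpoint still terminates. A toy worklist loop illustrating the pattern, with all names invented for the sketch:

    #include <cstdint>
    #include <map>
    #include <set>
    #include <vector>

    // Toy SSA graph: each op either anchors a fixed size or copies its input.
    struct ToyOp {
      int result;          // result value id
      int operand;         // -1 if none
      int64_t anchorSize;  // 0 if not an anchor
    };

    // Idempotent join into a per-value candidate set; returns "changed".
    static bool join(std::set<int64_t> &lattice, const std::set<int64_t> &in) {
      size_t before = lattice.size();
      lattice.insert(in.begin(), in.end());
      return lattice.size() != before;
    }

    int main() {
      // %0 = anchor(8); %1 = copy(%0)
      std::vector<ToyOp> ops = {{0, -1, 8}, {1, 0, 0}};
      std::map<int, std::set<int64_t>> lattices;
      bool changed = true;
      while (changed) {  // fixpoint: terminates since joins only grow sets
        changed = false;
        for (const ToyOp &op : ops) {
          if (op.anchorSize)  // anchors re-seed on every visit; a no-op once
            changed |= join(lattices[op.result], {op.anchorSize});
          else if (op.operand >= 0)
            changed |= join(lattices[op.result], lattices[op.operand]);
        }
      }
      return lattices[1].count(8) ? 0 : 1;  // size reached the copy's result
    }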
compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir

Lines changed: 47 additions & 0 deletions
@@ -79,6 +79,53 @@ func.func @chain_propagation_transpose(
 
 // -----
 
+// Chain propagation with dynamic shapes: tile sizes propagate the same way
+// regardless of whether tensor dimensions are static or dynamic.
+
+#layout_dyn = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1], batch_tile = [1, 8], outer_tile = [1, 1],
+  thread_tile = [1, 1], element_tile = [8, 8],
+  subgroup_strides = [0, 0], thread_strides = [0, 0]>
+
+// CHECK-LABEL: @chain_propagation_dynamic
+func.func @chain_propagation_dynamic(
+    %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %a = iree_vector_ext.to_layout %arg0 to layout(#layout_dyn) : tensor<?x?xf32>
+  %empty_ab = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = [array<i64: 8>, array<i64: 64>]
+  %ab = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%a, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%empty_ab : tensor<?x?xf32>) {
+  ^bb0(%in0: f32, %in1: f32, %out: f32):
+    %add = arith.addf %in0, %in1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %empty_c = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = [array<i64: 8>, array<i64: 64>]
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%ab : tensor<?x?xf32>) outs(%empty_c : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %neg = arith.negf %in : f32
+    linalg.yield %neg : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+
+// -----
+
 // scf.for propagation through iter_args.
 // The to_layout inside the loop should propagate tile sizes to the
 // loop iter_args and through the scf.yield.
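The CHECK lines in the new test expect tile sizes of 8 and 64 for both generics. Assuming the undistributed shape of a nested_layout is the per-dimension product of all tile levels (subgroup * batch * outer * thread * element), the expected values fall out directly; the snippet below just verifies that arithmetic:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Tile levels of #layout_dyn above, per dimension.
      int64_t subgroup[2] = {1, 1}, batch[2] = {1, 8}, outer[2] = {1, 1};
      int64_t thread[2] = {1, 1}, element[2] = {8, 8};
      int64_t undistributed[2];
      // Assumption: undistributed shape = product of all tile levels per dim.
      for (int d = 0; d < 2; ++d)
        undistributed[d] =
            subgroup[d] * batch[d] * outer[d] * thread[d] * element[d];
      assert(undistributed[0] == 8);   // 1*1*1*1*8
      assert(undistributed[1] == 64);  // 1*8*1*1*8
      return 0;
    }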
