2121#define DEBUG_TYPE " iree-codegen-vector-tile-size-analysis"
2222
2323// The purpose of this analysis is to propagate information about the
24- // undistributed vector tile size across the operation graph. The vector tile
25- // size is important information for the vectorization of operations.
26- // For example, the vector tile size can be used by GenericVectorization to
27- // introduce the necessary masking in the presence of padding/masking.
24+ // vector tile size across the operation graph. The vector tile size is
25+ // important information for the vectorization of operations. For example, the
26+ // vector tile size can be used by GenericVectorization to introduce the
27+ // necessary masking in the presence of padding/masking.
2828//
2929// The analysis is a bi-directional dataflow analysis building on top of the
3030// upstream MLIR dataflow analysis framework. To implement the bi-directional
4343// As the set union can not result in a conflict, no lattice state for top
4444// (overdefined) is required in this lattice.
4545//
46- // The lattice is initialized from `to_layout` operations.
46+ // The lattice is initialized from anchor operations that provide information
47+ // about vector tile size (e.g., `to_layout`).
4748//
4849// Forward propagation and backward propagation work similarly:
4950// - For elementwise operations, candidates from the different operands
@@ -204,19 +205,14 @@ static bool isDuplicatable(Value val) {
204205 if (defOp->hasTrait <OpTrait::ConstantLike>()) {
205206 return true ;
206207 }
207- // Catches linalg.fill that has been lowered/fused into linalg.generic form
208- // (scalar input broadcast into tensor.empty output).
209- if (auto genericOp = dyn_cast<linalg::GenericOp>(defOp)) {
210- if (genericOp.getNumDpsInputs () == 1 && genericOp.getNumDpsInits () == 1 &&
211- !isa<ShapedType>(genericOp.getDpsInputs ()[0 ].getType ())) {
212- Value init = genericOp.getDpsInits ()[0 ];
213- if (init.getDefiningOp <tensor::EmptyOp>()) {
214- return true ;
215- }
216- }
217- }
218- if (auto fillOp = dyn_cast<linalg::FillOp>(defOp)) {
219- if (fillOp.getOutputs ()[0 ].getDefiningOp <tensor::EmptyOp>()) {
208+ // A linalg op that doesn't read any tensor data (e.g., linalg.fill or a
209+ // fill-like linalg.generic broadcasting a scalar) is a generator and
210+ // duplicatable.
211+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(defOp)) {
212+ if (llvm::none_of (linalgOp->getOpOperands (), [&](OpOperand &operand) {
213+ return isa<ShapedType>(operand.get ().getType ()) &&
214+ linalgOp.payloadUsesValueFromOperand (&operand);
215+ })) {
220216 return true ;
221217 }
222218 }
@@ -258,20 +254,6 @@ class TileSizeForwardAnalysis
258254public:
259255 using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
260256
261- LogicalResult initialize (Operation *top) override {
262- // Seed to_layout anchors before the regular initialization. This ensures
263- // seeds are set even for to_layout ops in regions that DeadCodeAnalysis
264- // hasn't yet marked as live during init.
265- top->walk ([&](ToLayoutOp toLayout) {
266- LDBG () << " Anchor: " << toLayout;
267- auto candidates = TileSizeCandidates::fromSizes (
268- toLayout.getLayout ().getUndistributedShape ());
269- auto *lattice = getLatticeElement (toLayout.getResult ());
270- propagateIfChanged (lattice, lattice->join (candidates));
271- });
272- return SparseForwardDataFlowAnalysis::initialize (top);
273- }
274-
275257 void setToEntryState (TileSizeLattice *lattice) override {
276258 // Entry state is uninitialized (identity for join).
277259 propagateIfChanged (lattice, lattice->join (TileSizeCandidates ()));
@@ -280,9 +262,12 @@ class TileSizeForwardAnalysis
280262 LogicalResult visitOperation (Operation *op,
281263 ArrayRef<const TileSizeLattice *> operands,
282264 ArrayRef<TileSizeLattice *> results) override {
283- // to_layout: don't propagate operand forward (anchor boundary).
284- // Seeding is done in initialize().
285- if (isa<ToLayoutOp>(op)) {
265+ // to_layout: seed from layout, don't propagate operand forward.
266+ if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
267+ LDBG () << " Anchor: " << toLayout;
268+ auto candidates = TileSizeCandidates::fromSizes (
269+ toLayout.getLayout ().getUndistributedShape ());
270+ propagateIfChanged (results[0 ], results[0 ]->join (candidates));
286271 return success ();
287272 }
288273
0 commit comments