1616#include " mlir/Dialect/Tensor/IR/Tensor.h"
1717#include " mlir/IR/SymbolTable.h"
1818
19- #define DEBUG_TYPE " iree-codegen-vector-tile-size-analysis "
19+ #define DEBUG_TYPE " iree-codegen-materialize- vector-tile-sizes "
2020
2121// The purpose of this analysis is to propagate information about the
2222// vector tile size across the operation graph. The vector tile size is
@@ -83,11 +83,12 @@ class TileSizes {
8383
8484 unsigned rank () const { return dims.size (); }
8585 bool empty () const { return dims.empty (); }
86- const llvm::SmallVector <int64_t > & getDims () const { return dims; }
86+ const ArrayRef <int64_t > getDims () const { return dims; }
8787
8888 int64_t operator [](unsigned i) const { return dims[i]; }
8989
90- // / Returns true if all dimensions have a defined (positive) tile size.
90+ // / Returns true if the tile sizes are non-empty and every dimension has a
91+ // / concrete tile size (not uninitialized or overdefined).
9192 bool isDefined () const {
9293 return !empty () && llvm::all_of (dims, [](int64_t v) {
9394 return v != kUninitialized && v != kOverdefined ;
@@ -236,12 +237,11 @@ class TileSizeLattice : public dataflow::Lattice<TileSizes> {
236237
237238// / Read the TileSizes from a lattice, returning empty tile sizes if the lattice
238239// / value is from a duplicatable operation.
239- static const TileSizes getTileSizesFor (Value val,
240- const TileSizeLattice *lattice) {
240+ static TileSizes getTileSizesFor (Value val, const TileSizeLattice *lattice) {
241241 if (!lattice) {
242242 return {};
243243 }
244- auto &tileSizes = lattice->getValue ();
244+ const TileSizes &tileSizes = lattice->getValue ();
245245 if (tileSizes.empty ()) {
246246 return {};
247247 }
@@ -280,22 +280,20 @@ class TileSizeForwardAnalysis
280280 unsigned numLoops = linalgOp.getNumLoops ();
281281 TileSizes iterTileSizes (numLoops);
282282 for (OpOperand &operand : linalgOp->getOpOperands ()) {
283- auto &ts = getTileSizesFor (operand. get (),
284- operands[operand.getOperandNumber ()]);
285- if (ts .empty ()) {
283+ TileSizes tileSizes = getTileSizesFor (
284+ operand. get (), operands[operand.getOperandNumber ()]);
285+ if (tileSizes .empty ()) {
286286 continue ;
287287 }
288288 AffineMap map = linalgOp.getMatchingIndexingMap (&operand);
289289 assert (map.getNumDims () == numLoops);
290- iterTileSizes.merge (ts .mapToIterationSpace (map));
290+ iterTileSizes.merge (tileSizes .mapToIterationSpace (map));
291291 }
292292 for (unsigned i = 0 ; i < linalgOp.getNumDpsInits (); ++i) {
293293 OpOperand *init = linalgOp.getDpsInitOperand (i);
294294 AffineMap map = linalgOp.getMatchingIndexingMap (init);
295- auto resultTileSizes = iterTileSizes.mapFromIterationSpace (map);
296- if (!resultTileSizes.empty ()) {
297- propagateIfChanged (results[i], results[i]->join (resultTileSizes));
298- }
295+ TileSizes resultTileSizes = iterTileSizes.mapFromIterationSpace (map);
296+ propagateIfChanged (results[i], results[i]->join (resultTileSizes));
299297 }
300298 return success ();
301299 }
@@ -322,46 +320,44 @@ class TileSizeBackwardAnalysis
322320 // to_layout is always an anchor op; Propagate tile sizes backward to the
323321 // input.
324322 if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
325- auto &ts = getTileSizesFor (toLayout.getResult (), results[0 ]);
326- if (!ts.empty ()) {
327- TileSizeLattice *inputLattice = operands[0 ];
328- propagateIfChanged (inputLattice, inputLattice->meet (ts));
329- }
323+ TileSizes tileSizes = getTileSizesFor (toLayout.getResult (), results[0 ]);
324+ TileSizeLattice *inputLattice = operands[0 ];
325+ propagateIfChanged (inputLattice, inputLattice->meet (tileSizes));
330326 return success ();
331327 }
332328
333329 // Linalg ops: propagate through indexing maps.
334330 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
335331 unsigned numLoops = linalgOp.getNumLoops ();
336332 TileSizes iterTileSizes (numLoops);
337- // Gather result tile sizes into iteration space via DPS init maps.
333+ // Gather result tile sizes into iteration space via init maps.
338334 for (auto [result, resultLattice] :
339- llvm::zip (linalgOp. getOperation () ->getResults (), results)) {
340- auto &ts = getTileSizesFor (result, resultLattice);
341- if (ts .empty ()) {
335+ llvm::zip_equal (linalgOp->getResults (), results)) {
336+ TileSizes tileSizes = getTileSizesFor (result, resultLattice);
337+ if (tileSizes .empty ()) {
342338 continue ;
343339 }
344340 unsigned resultIdx = cast<OpResult>(result).getResultNumber ();
345341 OpOperand *init = linalgOp.getDpsInitOperand (resultIdx);
346342 AffineMap map = linalgOp.getMatchingIndexingMap (init);
347343 assert (map.getNumDims () == numLoops);
348- iterTileSizes.merge (ts .mapToIterationSpace (map));
344+ iterTileSizes.merge (tileSizes .mapToIterationSpace (map));
349345 }
350346 // Gather operand tile sizes into iteration space.
351347 for (OpOperand &operand : linalgOp->getOpOperands ()) {
352- auto &ts = getTileSizesFor (operand. get (),
353- operands[operand.getOperandNumber ()]);
354- if (ts .empty ()) {
348+ TileSizes tileSizes = getTileSizesFor (
349+ operand. get (), operands[operand.getOperandNumber ()]);
350+ if (tileSizes .empty ()) {
355351 continue ;
356352 }
357353 AffineMap map = linalgOp.getMatchingIndexingMap (&operand);
358354 assert (map.getNumDims () == numLoops);
359- iterTileSizes.merge (ts .mapToIterationSpace (map));
355+ iterTileSizes.merge (tileSizes .mapToIterationSpace (map));
360356 }
361357 // Map iteration space tile sizes back to each operand.
362358 for (OpOperand &operand : linalgOp->getOpOperands ()) {
363359 AffineMap map = linalgOp.getMatchingIndexingMap (&operand);
364- auto operandTileSizes = iterTileSizes.mapFromIterationSpace (map);
360+ TileSizes operandTileSizes = iterTileSizes.mapFromIterationSpace (map);
365361 if (operandTileSizes.empty ()) {
366362 continue ;
367363 }
@@ -398,31 +394,18 @@ static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp,
398394 TileSizes iterTileSizes (numLoops);
399395 for (OpOperand &operand : linalgOp->getOpOperands ()) {
400396 Value val = operand.get ();
401- auto *lattice = solver.lookupState <TileSizeLattice>(val);
402- auto &ts = getTileSizesFor (val, lattice);
403- if (ts .empty ()) {
397+ const TileSizeLattice *lattice = solver.lookupState <TileSizeLattice>(val);
398+ TileSizes tileSize = getTileSizesFor (val, lattice);
399+ if (tileSize .empty ()) {
404400 continue ;
405401 }
406402 AffineMap map = linalgOp.getMatchingIndexingMap (&operand);
407403 assert (map.getNumDims () == numLoops);
408- iterTileSizes.merge (ts .mapToIterationSpace (map));
404+ iterTileSizes.merge (tileSize .mapToIterationSpace (map));
409405 }
410406 return iterTileSizes;
411407}
412408
413- // / Given a linalg op and the solver, compute per-dimension tile sizes.
414- // / Returns a vector of one tile size per iteration dimension, or nullopt if
415- // / any dimension is uninitialized or overdefined.
416- static std::optional<SmallVector<int64_t >>
417- getPerDimTileSizes (linalg::LinalgOp linalgOp, const DataFlowSolver &solver) {
418- TileSizes tileSizes = getIterationSpaceTileSizes (linalgOp, solver);
419- if (!tileSizes.isDefined ()) {
420- return std::nullopt ;
421- }
422- assert (tileSizes.rank () == linalgOp.getNumLoops ());
423- return tileSizes.getDims ();
424- }
425-
426409// ===----------------------------------------------------------------------===//
427410// MaterializeVectorTileSizesPass
428411// ===----------------------------------------------------------------------===//
@@ -437,7 +420,7 @@ class MaterializeVectorTileSizesPass final
437420 MaterializeVectorTileSizesPass> {
438421public:
439422 void runOnOperation () override {
440- auto funcOp = getOperation ();
423+ FunctionOpInterface funcOp = getOperation ();
441424
442425 DataFlowSolver solver;
443426 dataflow::loadBaselineAnalyses (solver);
@@ -450,17 +433,21 @@ class MaterializeVectorTileSizesPass final
450433 }
451434
452435 funcOp->walk ([&](linalg::LinalgOp linalgOp) {
453- std::optional<SmallVector<int64_t >> perDimSizes =
454- getPerDimTileSizes (linalgOp, solver);
455- if (!perDimSizes) {
456- LDBG () << " Analysis did not determine tile size for" << *linalgOp;
436+ TileSizes tileSizes = getIterationSpaceTileSizes (linalgOp, solver);
437+ if (tileSizes.isOverdefined ()) {
438+ linalgOp.emitOpError ()
439+ << " tile size analysis did not determine a valid tile size" ;
440+ return ;
441+ }
442+ if (!tileSizes.isDefined ()) {
443+ LDBG () << " Analysis did not determine tile size for " << *linalgOp;
457444 return ;
458445 }
446+ assert (tileSizes.rank () == linalgOp.getNumLoops ());
459447
460- LDBG () << " Materializing tile size on " << *linalgOp;
461448 linalgOp->setAttr (
462449 kVectorTileSizesAttrName ,
463- DenseI64ArrayAttr::get (linalgOp->getContext (), *perDimSizes ));
450+ DenseI64ArrayAttr::get (linalgOp->getContext (), tileSizes. getDims () ));
464451 });
465452 }
466453};
0 commit comments