Skip to content

Commit 93e18b5

Browse files
committed
Address more PR feedback
Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
1 parent 8e9a9aa commit 93e18b5

2 files changed

Lines changed: 41 additions & 54 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1717
#include "mlir/IR/SymbolTable.h"
1818

19-
#define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis"
19+
#define DEBUG_TYPE "iree-codegen-materialize-vector-tile-sizes"
2020

2121
// The purpose of this analysis is to propagate information about the
2222
// vector tile size across the operation graph. The vector tile size is
@@ -83,11 +83,12 @@ class TileSizes {
8383

8484
unsigned rank() const { return dims.size(); }
8585
bool empty() const { return dims.empty(); }
86-
const llvm::SmallVector<int64_t> &getDims() const { return dims; }
86+
const ArrayRef<int64_t> getDims() const { return dims; }
8787

8888
int64_t operator[](unsigned i) const { return dims[i]; }
8989

90-
/// Returns true if all dimensions have a defined (positive) tile size.
90+
/// Returns true if the tile sizes are non-empty and every dimension has a
91+
/// concrete tile size (not uninitialized or overdefined).
9192
bool isDefined() const {
9293
return !empty() && llvm::all_of(dims, [](int64_t v) {
9394
return v != kUninitialized && v != kOverdefined;
@@ -236,12 +237,11 @@ class TileSizeLattice : public dataflow::Lattice<TileSizes> {
236237

237238
/// Read the TileSizes from a lattice, returning empty tile sizes if the lattice
238239
/// value is from a duplicatable operation.
239-
static const TileSizes getTileSizesFor(Value val,
240-
const TileSizeLattice *lattice) {
240+
static TileSizes getTileSizesFor(Value val, const TileSizeLattice *lattice) {
241241
if (!lattice) {
242242
return {};
243243
}
244-
auto &tileSizes = lattice->getValue();
244+
const TileSizes &tileSizes = lattice->getValue();
245245
if (tileSizes.empty()) {
246246
return {};
247247
}
@@ -280,22 +280,20 @@ class TileSizeForwardAnalysis
280280
unsigned numLoops = linalgOp.getNumLoops();
281281
TileSizes iterTileSizes(numLoops);
282282
for (OpOperand &operand : linalgOp->getOpOperands()) {
283-
auto &ts = getTileSizesFor(operand.get(),
284-
operands[operand.getOperandNumber()]);
285-
if (ts.empty()) {
283+
TileSizes tileSizes = getTileSizesFor(
284+
operand.get(), operands[operand.getOperandNumber()]);
285+
if (tileSizes.empty()) {
286286
continue;
287287
}
288288
AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
289289
assert(map.getNumDims() == numLoops);
290-
iterTileSizes.merge(ts.mapToIterationSpace(map));
290+
iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
291291
}
292292
for (unsigned i = 0; i < linalgOp.getNumDpsInits(); ++i) {
293293
OpOperand *init = linalgOp.getDpsInitOperand(i);
294294
AffineMap map = linalgOp.getMatchingIndexingMap(init);
295-
auto resultTileSizes = iterTileSizes.mapFromIterationSpace(map);
296-
if (!resultTileSizes.empty()) {
297-
propagateIfChanged(results[i], results[i]->join(resultTileSizes));
298-
}
295+
TileSizes resultTileSizes = iterTileSizes.mapFromIterationSpace(map);
296+
propagateIfChanged(results[i], results[i]->join(resultTileSizes));
299297
}
300298
return success();
301299
}
@@ -322,46 +320,44 @@ class TileSizeBackwardAnalysis
322320
// to_layout is always an anchor op; Propagate tile sizes backward to the
323321
// input.
324322
if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
325-
auto &ts = getTileSizesFor(toLayout.getResult(), results[0]);
326-
if (!ts.empty()) {
327-
TileSizeLattice *inputLattice = operands[0];
328-
propagateIfChanged(inputLattice, inputLattice->meet(ts));
329-
}
323+
TileSizes tileSizes = getTileSizesFor(toLayout.getResult(), results[0]);
324+
TileSizeLattice *inputLattice = operands[0];
325+
propagateIfChanged(inputLattice, inputLattice->meet(tileSizes));
330326
return success();
331327
}
332328

333329
// Linalg ops: propagate through indexing maps.
334330
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
335331
unsigned numLoops = linalgOp.getNumLoops();
336332
TileSizes iterTileSizes(numLoops);
337-
// Gather result tile sizes into iteration space via DPS init maps.
333+
// Gather result tile sizes into iteration space via init maps.
338334
for (auto [result, resultLattice] :
339-
llvm::zip(linalgOp.getOperation()->getResults(), results)) {
340-
auto &ts = getTileSizesFor(result, resultLattice);
341-
if (ts.empty()) {
335+
llvm::zip_equal(linalgOp->getResults(), results)) {
336+
TileSizes tileSizes = getTileSizesFor(result, resultLattice);
337+
if (tileSizes.empty()) {
342338
continue;
343339
}
344340
unsigned resultIdx = cast<OpResult>(result).getResultNumber();
345341
OpOperand *init = linalgOp.getDpsInitOperand(resultIdx);
346342
AffineMap map = linalgOp.getMatchingIndexingMap(init);
347343
assert(map.getNumDims() == numLoops);
348-
iterTileSizes.merge(ts.mapToIterationSpace(map));
344+
iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
349345
}
350346
// Gather operand tile sizes into iteration space.
351347
for (OpOperand &operand : linalgOp->getOpOperands()) {
352-
auto &ts = getTileSizesFor(operand.get(),
353-
operands[operand.getOperandNumber()]);
354-
if (ts.empty()) {
348+
TileSizes tileSizes = getTileSizesFor(
349+
operand.get(), operands[operand.getOperandNumber()]);
350+
if (tileSizes.empty()) {
355351
continue;
356352
}
357353
AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
358354
assert(map.getNumDims() == numLoops);
359-
iterTileSizes.merge(ts.mapToIterationSpace(map));
355+
iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
360356
}
361357
// Map iteration space tile sizes back to each operand.
362358
for (OpOperand &operand : linalgOp->getOpOperands()) {
363359
AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
364-
auto operandTileSizes = iterTileSizes.mapFromIterationSpace(map);
360+
TileSizes operandTileSizes = iterTileSizes.mapFromIterationSpace(map);
365361
if (operandTileSizes.empty()) {
366362
continue;
367363
}
@@ -398,31 +394,18 @@ static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp,
398394
TileSizes iterTileSizes(numLoops);
399395
for (OpOperand &operand : linalgOp->getOpOperands()) {
400396
Value val = operand.get();
401-
auto *lattice = solver.lookupState<TileSizeLattice>(val);
402-
auto &ts = getTileSizesFor(val, lattice);
403-
if (ts.empty()) {
397+
const TileSizeLattice *lattice = solver.lookupState<TileSizeLattice>(val);
398+
TileSizes tileSize = getTileSizesFor(val, lattice);
399+
if (tileSize.empty()) {
404400
continue;
405401
}
406402
AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
407403
assert(map.getNumDims() == numLoops);
408-
iterTileSizes.merge(ts.mapToIterationSpace(map));
404+
iterTileSizes.merge(tileSize.mapToIterationSpace(map));
409405
}
410406
return iterTileSizes;
411407
}
412408

413-
/// Given a linalg op and the solver, compute per-dimension tile sizes.
414-
/// Returns a vector of one tile size per iteration dimension, or nullopt if
415-
/// any dimension is uninitialized or overdefined.
416-
static std::optional<SmallVector<int64_t>>
417-
getPerDimTileSizes(linalg::LinalgOp linalgOp, const DataFlowSolver &solver) {
418-
TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver);
419-
if (!tileSizes.isDefined()) {
420-
return std::nullopt;
421-
}
422-
assert(tileSizes.rank() == linalgOp.getNumLoops());
423-
return tileSizes.getDims();
424-
}
425-
426409
//===----------------------------------------------------------------------===//
427410
// MaterializeVectorTileSizesPass
428411
//===----------------------------------------------------------------------===//
@@ -437,7 +420,7 @@ class MaterializeVectorTileSizesPass final
437420
MaterializeVectorTileSizesPass> {
438421
public:
439422
void runOnOperation() override {
440-
auto funcOp = getOperation();
423+
FunctionOpInterface funcOp = getOperation();
441424

442425
DataFlowSolver solver;
443426
dataflow::loadBaselineAnalyses(solver);
@@ -450,17 +433,21 @@ class MaterializeVectorTileSizesPass final
450433
}
451434

452435
funcOp->walk([&](linalg::LinalgOp linalgOp) {
453-
std::optional<SmallVector<int64_t>> perDimSizes =
454-
getPerDimTileSizes(linalgOp, solver);
455-
if (!perDimSizes) {
456-
LDBG() << "Analysis did not determine tile size for" << *linalgOp;
436+
TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver);
437+
if (tileSizes.isOverdefined()) {
438+
linalgOp.emitOpError()
439+
<< "tile size analysis did not determine a valid tile size";
440+
return;
441+
}
442+
if (!tileSizes.isDefined()) {
443+
LDBG() << "Analysis did not determine tile size for " << *linalgOp;
457444
return;
458445
}
446+
assert(tileSizes.rank() == linalgOp.getNumLoops());
459447

460-
LDBG() << "Materializing tile size on " << *linalgOp;
461448
linalgOp->setAttr(
462449
kVectorTileSizesAttrName,
463-
DenseI64ArrayAttr::get(linalgOp->getContext(), *perDimSizes));
450+
DenseI64ArrayAttr::get(linalgOp->getContext(), tileSizes.getDims()));
464451
});
465452
}
466453
};

compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: iree-opt --pass-pipeline='builtin.module(any(iree-codegen-materialize-vector-tile-sizes))' --split-input-file %s | FileCheck %s
1+
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-materialize-vector-tile-sizes))' --split-input-file %s | FileCheck %s
22

33
// Elementwise chain from to_layout anchor.
44

0 commit comments

Comments
 (0)