From 1376f48d65fb3b606031572ee9f96440c441287e Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Fri, 27 Feb 2026 15:12:50 +0000 Subject: [PATCH 1/8] [VectorDistribute] Add VectorTileSizeAnalysis Signed-off-by: Lukas Sommer --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../Codegen/Common/GenericVectorization.cpp | 28 +- .../iree/compiler/Codegen/Common/Passes.td | 10 + .../Codegen/Common/VectorTileSizeAnalysis.cpp | 461 ++++++++++++++++++ .../compiler/Codegen/Common/test/BUILD.bazel | 1 + .../Codegen/Common/test/CMakeLists.txt | 1 + .../test/materialize_vector_tile_sizes.mlir | 196 ++++++++ .../Dialect/Codegen/IR/IREECodegenAttrs.h | 2 + .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 3 + 10 files changed, 702 insertions(+), 2 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 233f2760bf82..6aabd195093e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -174,6 +174,7 @@ iree_compiler_cc_library( "UnrollAnnotatedLoops.cpp", "UserConfig.cpp", "VectorLayoutAnalysis.cpp", + "VectorTileSizeAnalysis.cpp", "VectorTransferLowering.cpp", "VectorizeMemrefCopy.cpp", "VerifyPipelineConstraints.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 4535db2e6fde..f3ca75204682 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -167,6 +167,7 @@ iree_cc_library( "UnrollAnnotatedLoops.cpp" "UserConfig.cpp" "VectorLayoutAnalysis.cpp" + "VectorTileSizeAnalysis.cpp" "VectorTransferLowering.cpp" 
"VectorizeMemrefCopy.cpp" "VerifyPipelineConstraints.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp index e539bacf0354..be265a29db97 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp @@ -30,8 +30,10 @@ namespace mlir::iree_compiler { #include "iree/compiler/Codegen/Common/Passes.h.inc" namespace { -// Returns the vector sizes from the local lowering config or try to infer them -// from the tensor shapes and tiled loops in the IR. + +// Returns the vector sizes from the local lowering config, materialized +// tile size attributes, or tries to infer them from the tensor shapes and +// tiled loops in the IR. static std::optional getVectorSizes(Operation *op, bool useConfiguredVectorSizes) { // Get vector sizes from the lowering config, if available in the op itself. @@ -80,6 +82,28 @@ getVectorSizes(Operation *op, bool useConfiguredVectorSizes) { LDBG() << "Failed to get configured vector sizes, fall back to inference"; } + // Try to get vector sizes from materialized tile size attribute. + // The attribute is an array of per-dimension candidate lists; use the + // maximum from each dimension. + if (auto tileSizesAttr = + op->getAttrOfType(kVectorTileSizesAttrName)) { + SmallVector vectorSizes; + bool valid = !tileSizesAttr.empty(); + for (auto dimAttr : tileSizesAttr) { + auto dimSizes = cast(dimAttr); + if (dimSizes.empty()) { + valid = false; + break; + } + vectorSizes.push_back(*llvm::max_element(dimSizes.asArrayRef())); + } + if (valid) { + LDBG() << "Use vector sizes from materialized tile size attribute"; + SmallVector scalableFlags(vectorSizes.size(), false); + return std::make_pair(vectorSizes, scalableFlags); + } + } + // Try to infer the vector sizes from the IR. 
std::optional> vectorSizes; SmallVector scalableFlags; diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 64c81e478cc2..60890443b962 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -822,6 +822,16 @@ def MaterializeEncodingIntoNopPass : let summary = "Drop the encodings from tensor types with encodings."; } +def MaterializeVectorTileSizesPass : + InterfacePass<"iree-codegen-materialize-vector-tile-sizes", + "mlir::FunctionOpInterface"> { + let summary = "Propagate vector tile sizes and materialize as attribute"; + let description = [{ + Propagate vector tile sizes from to_layout anchors and materialize them as + discardable attributes on compute ops. + }]; +} + def MaterializeTuningSpecsPass : Pass<"iree-codegen-materialize-tuning-specs", "ModuleOp"> { let summary = "Load tuning spec transform dialect libraries and encode them in the module"; diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp new file mode 100644 index 000000000000..9b62d46bb5dd --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp @@ -0,0 +1,461 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" + +#include "llvm/Support/DebugLog.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "llvm/ADT/SmallSet.h" + +#define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis" + +namespace mlir::iree_compiler { + +using namespace IREE::VectorExt; + +using TileSizeSet = llvm::SmallSet; + +/// Per-dimension tile size candidates. Each dimension has an independent set +/// of candidate tile sizes. +class TileSizeCandidates { +public: + TileSizeCandidates() = default; + explicit TileSizeCandidates(unsigned rank) : dims(rank) {} + + unsigned rank() const { return dims.size(); } + bool empty() const { return dims.empty(); } + + const TileSizeSet &operator[](unsigned i) const { return dims[i]; } + TileSizeSet &operator[](unsigned i) { return dims[i]; } + + /// Merge candidates from `other` into this. Returns true if anything changed. + bool merge(const TileSizeCandidates &other) { + assert(rank() == other.rank() && "rank mismatch"); + bool changed = false; + for (unsigned i = 0; i < rank(); ++i) { + for (int64_t v : other.dims[i]) { + changed |= dims[i].insert(v).second; + } + } + return changed; + } + + /// Merge a single concrete tile size (one value per dimension). + /// Values of -1 (unknown) are skipped. If this object is uninitialized + /// (rank 0), it is initialized from the size of `concreteSizes`. 
+ bool merge(ArrayRef concreteSizes) { + if (empty()) { + dims.resize(concreteSizes.size()); + } + assert(rank() == concreteSizes.size() && "rank mismatch"); + bool changed = false; + for (unsigned i = 0; i < rank(); ++i) { + if (concreteSizes[i] != -1) { + changed |= dims[i].insert(concreteSizes[i]).second; + } + } + return changed; + } + + /// Returns true if any dimension has more than one candidate. + bool hasAlternatives() const { + return llvm::any_of(dims, + [](const TileSizeSet &s) { return s.size() > 1; }); + } + + /// Map from operand space to iteration space via an indexing map. + TileSizeCandidates mapToIterationSpace(AffineMap indexingMap, + unsigned numLoops) const { + TileSizeCandidates result(numLoops); + for (unsigned i = 0; i < indexingMap.getNumResults(); ++i) { + auto dimExpr = dyn_cast(indexingMap.getResult(i)); + if (!dimExpr) { + continue; + } + unsigned iterDim = dimExpr.getPosition(); + for (int64_t v : dims[i]) { + result.dims[iterDim].insert(v); + } + } + return result; + } + + /// Map from iteration space to operand space via an indexing map. + /// Returns empty TileSizeCandidates if any operand dim can't be determined. + TileSizeCandidates mapFromIterationSpace(AffineMap indexingMap) const { + unsigned numResults = indexingMap.getNumResults(); + TileSizeCandidates result(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto dimExpr = dyn_cast(indexingMap.getResult(i)); + if (!dimExpr) { + return {}; + } + unsigned iterDim = dimExpr.getPosition(); + if (iterDim >= rank() || dims[iterDim].empty()) { + return {}; + } + result.dims[i] = dims[iterDim]; + } + return result; + } + +private: + SmallVector dims; +}; + +/// Returns true if the operation is trivially duplicatable and should not +/// propagate merged tile sizes across independent consumers. 
+static bool isDuplicatable(Value val) { + Operation *defOp = val.getDefiningOp(); + if (!defOp) { + return false; + } + if (isa(defOp)) { + return true; + } + if (defOp->hasTrait()) { + return true; + } + // Catches linalg.fill that has been lowered/fused into linalg.generic form + // (scalar input broadcast into tensor.empty output). + if (auto genericOp = dyn_cast(defOp)) { + if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 && + !isa(genericOp.getDpsInputs()[0].getType())) { + Value init = genericOp.getDpsInits()[0]; + if (init.getDefiningOp()) { + return true; + } + } + } + if (auto fillOp = dyn_cast(defOp)) { + if (fillOp.getOutputs()[0].getDefiningOp()) { + return true; + } + } + return false; +} + +struct TileSizeState { + void propagateForward(Value val); + void propagateBackward(Value val); + + /// Merge candidates into a value and enqueue if anything changed. + void mergeAndEnqueue(Value val, const TileSizeCandidates &candidates) { + if (!isa(val.getType())) { + return; + } + if (candidates.empty()) { + return; + } + // If val is not yet in the map, inserting it may rehash the DenseMap + // and invalidate `candidates` if it aliases an existing entry. Copy + // directly into the new entry to avoid the dangling reference. + if (!tileSizes.count(val)) { + tileSizes[val] = candidates; + } else { + if (!tileSizes[val].merge(candidates)) { + return; + } + } + // We don't forward multiple alternatives from operations that are easy to + // duplicate. CSE will deduplicate DPS init operands, creating edges between + // unrelated compute operations. Propagating different vector tile sizes via + // shared DPS inits doesn't provide any value in that case. + if (isDuplicatable(val) && tileSizes[val].hasAlternatives()) { + return; + } + // Propagate the update. + forward.push(val); + backward.push(val); + } + + /// Convenience: merge a single concrete tile size and enqueue if changed. 
+ void mergeAndEnqueue(Value val, ArrayRef concreteSizes) { + TileSizeCandidates candidates(concreteSizes.size()); + candidates.merge(concreteSizes); + mergeAndEnqueue(val, candidates); + } + + bool hasTileSize(Value val) const { return tileSizes.count(val); } + + const TileSizeCandidates &getCandidates(Value val) const { + static const TileSizeCandidates empty; + auto it = tileSizes.find(val); + if (it == tileSizes.end()) { + return empty; + } + return it->second; + } + + /// Propagate through a linalg.generic: given known tile sizes on some + /// operands, infer tile sizes for other operands via indexing maps. + void propagateGenericOp(linalg::GenericOp genericOp); + + DenseMap tileSizes; + std::queue forward; + std::queue backward; +}; + +/// Collect per-dimension tile size candidate sets from a linalg op's operands. +/// Returns a TileSizeCandidates of size numLoops, where each dimension is the +/// union of all candidate tile sizes for that iteration dimension across all +/// operands. +static TileSizeCandidates +getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, + const TileSizeState &state) { + unsigned numLoops = linalgOp.getNumLoops(); + TileSizeCandidates result(numLoops); + for (OpOperand &operand : linalgOp->getOpOperands()) { + auto &candidates = state.getCandidates(operand.get()); + if (candidates.empty()) { + continue; + } + AffineMap map = linalgOp.getMatchingIndexingMap(&operand); + auto mapped = candidates.mapToIterationSpace(map, numLoops); + result.merge(mapped); + } + return result; +} + +void TileSizeState::propagateGenericOp(linalg::GenericOp genericOp) { + auto perDimSizes = getIterationSpaceTileSizes(genericOp, *this); + + // Map per-dimension iteration-space candidates to each operand's dimensions + // via its indexing map. 
+ for (OpOperand &operand : genericOp->getOpOperands()) { + AffineMap map = genericOp.getMatchingIndexingMap(&operand); + auto operandCandidates = perDimSizes.mapFromIterationSpace(map); + if (operandCandidates.empty()) { + continue; + } + mergeAndEnqueue(operand.get(), operandCandidates); + } + + // Propagate to results via their corresponding init operands. + for (auto [init, result] : + llvm::zip_equal(genericOp.getDpsInits(), genericOp.getResults())) { + mergeAndEnqueue(result, getCandidates(init)); + } +} + +void TileSizeState::propagateForward(Value val) { + auto &candidates = getCandidates(val); + if (candidates.empty()) { + return; + } + LDBG() << "Propagating tile size forward for: " << val; + + for (OpOperand &use : val.getUses()) { + Operation *user = use.getOwner(); + unsigned operandIdx = use.getOperandNumber(); + + // scf.for: propagate to tied loop body arg and result. + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getTiedLoopRegionIterArg(&use); + Value result = forOp.getTiedLoopResult(&use); + mergeAndEnqueue(arg, candidates); + mergeAndEnqueue(result, candidates); + continue; + } + + // scf.yield: propagate to parent op's results/args. + if (auto yieldOp = dyn_cast(user)) { + Operation *parentOp = yieldOp->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + Value arg = forOp.getRegionIterArg(operandIdx); + Value result = forOp->getResult(operandIdx); + mergeAndEnqueue(arg, candidates); + mergeAndEnqueue(result, candidates); + continue; + } + if (auto ifOp = dyn_cast(parentOp)) { + Value result = ifOp->getResult(operandIdx); + mergeAndEnqueue(result, candidates); + continue; + } + } + + // Elementwise ops: propagate to all results. + if (OpTrait::hasElementwiseMappableTraits(user)) { + for (OpResult result : user->getOpResults()) { + mergeAndEnqueue(result, candidates); + } + continue; + } + + // linalg.generic: propagate through indexing maps. 
+ if (auto genericOp = dyn_cast(user)) { + propagateGenericOp(genericOp); + continue; + } + } +} + +void TileSizeState::propagateBackward(Value val) { + LDBG() << "Propagating tile size backward for: " << val; + auto &candidates = getCandidates(val); + if (candidates.empty()) { + return; + } + + // Block arguments (e.g., scf.for iter_args). + if (auto blockArg = dyn_cast(val)) { + Operation *parent = val.getParentBlock()->getParentOp(); + if (auto forOp = dyn_cast(parent)) { + OpOperand *yielded = forOp.getTiedLoopYieldedValue(blockArg); + OpOperand *init = forOp.getTiedLoopInit(blockArg); + if (yielded) { + mergeAndEnqueue(yielded->get(), candidates); + } + if (init) { + mergeAndEnqueue(init->get(), candidates); + } + } + return; + } + + Operation *defOp = val.getDefiningOp(); + if (!defOp) { + return; + } + + // Elementwise ops: propagate to all operands. + if (OpTrait::hasElementwiseMappableTraits(defOp)) { + for (OpOperand &operand : defOp->getOpOperands()) { + if (isa(operand.get().getType())) { + mergeAndEnqueue(operand.get(), candidates); + } + } + return; + } + + // linalg.generic: propagate through indexing maps. + if (auto genericOp = dyn_cast(defOp)) { + unsigned resultIdx = cast(val).getResultNumber(); + mergeAndEnqueue(genericOp.getDpsInitOperand(resultIdx)->get(), candidates); + propagateGenericOp(genericOp); + return; + } + + // to_layout: propagate to input. + // We only propagate backward for to_layout, not forward, as to_layout is an + // anchor for initialization itself. + if (auto toLayout = dyn_cast(defOp)) { + mergeAndEnqueue(toLayout.getInput(), candidates); + return; + } + + // scf.for results: propagate to yield and init. 
+ if (auto forOp = dyn_cast(defOp)) { + unsigned resultIdx = cast(val).getResultNumber(); + Value init = forOp.getInits()[resultIdx]; + auto yieldOp = cast(forOp.getBody()->getTerminator()); + mergeAndEnqueue(init, candidates); + mergeAndEnqueue(yieldOp.getOperand(resultIdx), candidates); + return; + } + + // scf.if results: propagate to yields in both regions. + if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIdx = cast(val).getResultNumber(); + auto thenYield = cast(ifOp.thenBlock()->getTerminator()); + mergeAndEnqueue(thenYield.getOperand(resultIdx), candidates); + assert(ifOp.elseBlock() && "scf.if with results must have an else block"); + auto elseYield = cast(ifOp.elseBlock()->getTerminator()); + mergeAndEnqueue(elseYield.getOperand(resultIdx), candidates); + return; + } +} + +/// Run the VectorTileSizeAnalysis on the given root operation. +static void runAnalysis(Operation *root, TileSizeState &state) { + // Initialize from to_layout anchors. + root->walk([&](ToLayoutOp toLayout) { + SmallVector undistShape = + toLayout.getLayout().getUndistributedShape(); + LDBG() << "Anchor: " << toLayout; + state.mergeAndEnqueue(toLayout.getResult(), undistShape); + }); + + // Fixpoint iteration: forward first, then backward. + while (!state.forward.empty() || !state.backward.empty()) { + if (!state.forward.empty()) { + Value val = state.forward.front(); + state.forward.pop(); + state.propagateForward(val); + } else { + Value val = state.backward.front(); + state.backward.pop(); + state.propagateBackward(val); + } + } +} + +/// Given a linalg op and the analysis state, compute per-dimension sets of +/// candidate tile sizes. Returns a vector of size numLoops, where each entry +/// is the deduplicated set of tile sizes for that iteration dimension. +/// Returns an empty vector if any dimension has no candidates. 
+static SmallVector> +getPerDimTileSizes(linalg::LinalgOp linalgOp, const TileSizeState &state) { + auto perDimSizes = getIterationSpaceTileSizes(linalgOp, state); + + // Return empty if any dimension has no candidates. + SmallVector> results; + for (unsigned i = 0; i < perDimSizes.rank(); ++i) { + if (perDimSizes[i].empty()) { + return {}; + } + results.push_back( + SmallVector(perDimSizes[i].begin(), perDimSizes[i].end())); + } + return results; +} + +//===----------------------------------------------------------------------===// +// MaterializeVectorTileSizesPass +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_MATERIALIZEVECTORTILESIZESPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { + +class MaterializeVectorTileSizesPass final + : public impl::MaterializeVectorTileSizesPassBase< + MaterializeVectorTileSizesPass> { +public: + void runOnOperation() override { + auto funcOp = getOperation(); + + TileSizeState state; + runAnalysis(funcOp, state); + + funcOp->walk([&](linalg::LinalgOp linalgOp) { + auto perDimSizes = getPerDimTileSizes(linalgOp, state); + if (perDimSizes.empty()) { + return; + } + + LDBG() << "Materializing tile size on " << *linalgOp; + + SmallVector dimAttrs; + for (const auto &dimSizes : perDimSizes) { + dimAttrs.push_back( + DenseI64ArrayAttr::get(linalgOp->getContext(), dimSizes)); + } + linalgOp->setAttr(kVectorTileSizesAttrName, + ArrayAttr::get(linalgOp->getContext(), dimAttrs)); + }); + } +}; + +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index 83460c922156..7bdf6f727b3c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -102,6 +102,7 @@ iree_lit_test_suite( "materialize_user_config_from_tuning_spec.mlir", 
"materialize_user_configs.mlir", "materialize_vector_masking.mlir", + "materialize_vector_tile_sizes.mlir", "math_transform.mlir", "normalize_loop_bounds.mlir", "optimize_tensor_insert_extract_slices.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index b84a8121a982..5bb574e1f30e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -97,6 +97,7 @@ iree_lit_test_suite( "materialize_user_config_from_tuning_spec.mlir" "materialize_user_configs.mlir" "materialize_vector_masking.mlir" + "materialize_vector_tile_sizes.mlir" "math_transform.mlir" "normalize_loop_bounds.mlir" "optimize_tensor_insert_extract_slices.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir new file mode 100644 index 000000000000..3e6f0f3533c3 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir @@ -0,0 +1,196 @@ +// RUN: iree-opt --pass-pipeline='builtin.module(any(iree-codegen-materialize-vector-tile-sizes))' --split-input-file %s | FileCheck %s + +// Elementwise chain from to_layout anchor. 
+ +#layout = #iree_vector_ext.nested_layout< + subgroup_tile = [1], batch_tile = [8], outer_tile = [1], + thread_tile = [1], element_tile = [8], + subgroup_strides = [0], thread_strides = [0]> + +// CHECK-LABEL: @elementwise_from_anchor +func.func @elementwise_from_anchor(%arg0: tensor<63xf16>) -> tensor<63xf16> { + %empty = tensor.empty() : tensor<63xf16> + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + %0 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%arg0 : tensor<63xf16>) outs(%empty : tensor<63xf16>) { + ^bb0(%in: f16, %out: f16): + %add = arith.addf %in, %in : f16 + linalg.yield %add : f16 + } -> tensor<63xf16> + %1 = iree_vector_ext.to_layout %0 to layout(#layout) : tensor<63xf16> + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + %2 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%1 : tensor<63xf16>) outs(%empty : tensor<63xf16>) { + ^bb0(%in: f16, %out: f16): + %mul = arith.mulf %in, %in : f16 + linalg.yield %mul : f16 + } -> tensor<63xf16> + return %2 : tensor<63xf16> +} + +// ----- + +// Chain propagation with transpose: tile sizes must propagate through +// generic A's result to generic B, with B using a transposed indexing map. 
+ +#layout_2d = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], batch_tile = [1, 8], outer_tile = [1, 1], + thread_tile = [1, 1], element_tile = [8, 8], + subgroup_strides = [0, 0], thread_strides = [0, 0]> + +// CHECK-LABEL: @chain_propagation_transpose +func.func @chain_propagation_transpose( + %arg0: tensor<8x64xf32>, %arg1: tensor<8x64xf32>) -> tensor<64x8xf32> { + %a = iree_vector_ext.to_layout %arg0 to layout(#layout_2d) : tensor<8x64xf32> + %empty_ab = tensor.empty() : tensor<8x64xf32> + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + %ab = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%a, %arg1 : tensor<8x64xf32>, tensor<8x64xf32>) + outs(%empty_ab : tensor<8x64xf32>) { + ^bb0(%in0: f32, %in1: f32, %out: f32): + %add = arith.addf %in0, %in1 : f32 + linalg.yield %add : f32 + } -> tensor<8x64xf32> + %empty_t = tensor.empty() : tensor<64x8xf32> + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%ab : tensor<8x64xf32>) outs(%empty_t : tensor<64x8xf32>) { + ^bb0(%in: f32, %out: f32): + %neg = arith.negf %in : f32 + linalg.yield %neg : f32 + } -> tensor<64x8xf32> + return %result : tensor<64x8xf32> +} + +// ----- + +// scf.for propagation through iter_args. +// The to_layout inside the loop should propagate tile sizes to the +// loop iter_args and through the scf.yield. 
+ +#layout = #iree_vector_ext.nested_layout< + subgroup_tile = [8], batch_tile = [1], outer_tile = [1], + thread_tile = [64], element_tile = [1], + subgroup_strides = [1], thread_strides = [1]> + +// CHECK-LABEL: @scf_for_propagation +func.func @scf_for_propagation(%arg0: tensor<512xf32>, %lb: index, %ub: index, %step: index) -> tensor<512xf32> { + %empty = tensor.empty() : tensor<512xf32> + %cst = arith.constant 0.0 : f32 + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + %init = linalg.generic { + indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%cst : f32) outs(%empty : tensor<512xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<512xf32> + %result = scf.for %iv = %lb to %ub step %step iter_args(%iter = %init) -> tensor<512xf32> { + %laid_out = iree_vector_ext.to_layout %iter to layout(#layout) : tensor<512xf32> + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + %updated = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%laid_out, %arg0 : tensor<512xf32>, tensor<512xf32>) outs(%empty : tensor<512xf32>) { + ^bb0(%in0: f32, %in1: f32, %out: f32): + %add = arith.addf %in0, %in1 : f32 + linalg.yield %add : f32 + } -> tensor<512xf32> + scf.yield %updated : tensor<512xf32> + } + return %result : tensor<512xf32> +} + +// ----- + +#layout_a = #iree_vector_ext.nested_layout< + subgroup_tile = [1], batch_tile = [8], outer_tile = [1], + thread_tile = [1], element_tile = [8], + subgroup_strides = [0], thread_strides = [0]> + +#layout_b = #iree_vector_ext.nested_layout< + subgroup_tile = [8, 1], batch_tile = [1, 8], outer_tile = [1, 1], + thread_tile = [64, 1], element_tile = [1, 8], + subgroup_strides = [1, 0], thread_strides = [1, 0]> + +#layout_c = #iree_vector_ext.nested_layout< + subgroup_tile = [8], batch_tile 
= [1], outer_tile = [1], + thread_tile = [64], element_tile = [1], + subgroup_strides = [1], thread_strides = [1]> + +// CHECK-LABEL: @contraction_indexing_maps +func.func @contraction_indexing_maps( + %a: tensor<63xf16>, %b: tensor<512x63xf16>, %c: tensor<512xf32>) -> tensor<512xf32> { + %al = iree_vector_ext.to_layout %a to layout(#layout_a) : tensor<63xf16> + %bl = iree_vector_ext.to_layout %b to layout(#layout_b) : tensor<512x63xf16> + %cl = iree_vector_ext.to_layout %c to layout(#layout_c) : tensor<512xf32> + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d1)> + ], + iterator_types = ["reduction", "parallel"] + } ins(%al, %bl : tensor<63xf16>, tensor<512x63xf16>) outs(%cl : tensor<512xf32>) { + ^bb0(%in0: f16, %in1: f16, %out: f32): + %ext0 = arith.extf %in0 : f16 to f32 + %ext1 = arith.extf %in1 : f16 to f32 + %mul = arith.mulf %ext0, %ext1 : f32 + %add = arith.addf %mul, %out : f32 + linalg.yield %add : f32 + } -> tensor<512xf32> + return %result : tensor<512xf32> +} + +// ----- + +// scf.if propagation: tile size from the to_layout inside one branch +// should propagate through the scf.if result to consumers outside. 
+ +#layout_if = #iree_vector_ext.nested_layout< + subgroup_tile = [8], batch_tile = [1], outer_tile = [1], + thread_tile = [64], element_tile = [1], + subgroup_strides = [1], thread_strides = [1]> + +// CHECK-LABEL: @scf_if_propagation +func.func @scf_if_propagation(%arg0: tensor<512xf32>, %cond: i1) -> tensor<512xf32> { + %empty = tensor.empty() : tensor<512xf32> + %cst = arith.constant 0.0 : f32 + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<512xf32>) -> tensor<512xf32> + %if_result = scf.if %cond -> tensor<512xf32> { + %laid_out = iree_vector_ext.to_layout %arg0 to layout(#layout_if) : tensor<512xf32> + scf.yield %laid_out : tensor<512xf32> + } else { + scf.yield %fill : tensor<512xf32> + } + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%if_result : tensor<512xf32>) outs(%empty : tensor<512xf32>) { + ^bb0(%in: f32, %out: f32): + %neg = arith.negf %in : f32 + linalg.yield %neg : f32 + } -> tensor<512xf32> + return %result : tensor<512xf32> +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h index c0f5a4752da2..90ea99440bfb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h @@ -71,6 +71,8 @@ constexpr StringLiteral kSerializedTuningSpecAttrName = constexpr StringLiteral kKernelConfigSpecName = "__kernel_config"; constexpr StringLiteral kUkernelAttrName = "iree_codegen.ukernel"; constexpr StringLiteral kUKernelProviderName = "iree_codegen.ukernel_provider"; +constexpr StringLiteral kVectorTileSizesAttrName = + "iree_codegen.vector_tile_sizes"; //===----------------------------------------------------------------------===// // Helpers for getting/setting 
iree_codegen.translation_info attribute on a diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index a8a8c7ddfc3c..325914eb2647 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -276,6 +276,9 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager, funcPassManager.addPass(IREE::LinalgExt::createDecomposeIm2colPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); + if (enableMasking) { + funcPassManager.addPass(createMaterializeVectorTileSizesPass()); + } // Vectorize. GenericVectorizationPassOptions options; options.vectorizeCopies = vectorizeCopies; From 5aa488911f48299bec718cb3ea4d9729d6afdc89 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Fri, 6 Mar 2026 14:47:05 +0000 Subject: [PATCH 2/8] Use dataflow framework for vector tile size analysis Signed-off-by: Lukas Sommer --- .../Codegen/Common/VectorTileSizeAnalysis.cpp | 556 ++++++++++-------- 1 file changed, 308 insertions(+), 248 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp index 9b62d46bb5dd..098cfe602a5e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp @@ -9,14 +9,70 @@ #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" #include "llvm/Support/DebugLog.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlow/Utils.h" +#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/SymbolTable.h" #include "llvm/ADT/SmallSet.h" #define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis" +// The 
purpose of this analysis is to propagate information about the +// undistributed vector tile size across the operation graph. The vector tile +// size is important information for the vectorization of operations. +// For example, the vector tile size can be used by GenericVectorization to +// introduce the necessary masking in the presence of padding/masking. +// +// The analysis is a bi-directional dataflow analysis building on top of the +// upstream MLIR dataflow analysis framework. To implement the bi-directional +// propagation, it combines a sparse forward analysis and a sparse backward +// analysis in the same solver. +// +// The lattice for the dataflow analysis is shared by both analyses (forward and +// backward). For each N-dimensional ShapedType SSA value, we have a lattice +// element comprising N sets, where each set contains the candidate tile sizes +// for that dimension. The bottom (uninitialized) state of the lattice is simply +// empty. The join/merge operation for two lattice elements is the per-dimension +// set-union of candidates. For example, in the 2D case: +// ({2, 4}, {16}) U ({8}, {32}) = ({2, 4, 8}, {16, 32}) +// +// As the sets can only grow, the join/meet operator is by definition monotonic. +// As the set union can not result in a conflict, no lattice state for top +// (overdefined) is required in this lattice. +// +// The lattice is initialized from `to_layout` operations. +// +// Forward propagation and backward propagation work similarly: +// - For elementwise operations, candidates from the different operands +// (forward) or results (backwards) are merged. The merged lattice state is +// then propagated to all results (forward) or operands (backward). +// - For linalg.generic operations, all available information from operands +// (forward) or results & operands (backward) is mapped to the iteration space +// based on indexing maps and merged into a single lattice state. 
That lattice +// state in the iteration space is then mapped to each result (forward) or +// operand (backward) based on indexing maps and the mapped state is +// propagated. +// +// The only exception to this process are duplicatable operations such as +// `tensor.empty`. CSE connects otherwise unrelated compute ops by deduplicating +// their DPS init operands to a single tensor.empty (or similar). To avoid +// cross-polluting the vector tile size of unrelated operations, propagation +// from duplicatable operations is stopped if they contain multiple candidates +// tile sizes in at least one dimension. +// +// For some other operations, no propagation rules are defined on purpose. For +// example, `extract_slice` and `insert_slice` operations are natural boundaries +// of tiling/padding, therefore no information is propagated across them. +// +// After the dataflow solver reaches a fixpoint, the +// MaterializeVectorTileSizesPass materializes the result as a discardable +// attribute. At this point, the result is a set of candidate vector tile sizes +// per iteration dimension. It is up to the users of the analysis how to select +// a tile size from the set of candidates. + namespace mlir::iree_compiler { using namespace IREE::VectorExt; @@ -24,45 +80,41 @@ using namespace IREE::VectorExt; using TileSizeSet = llvm::SmallSet; /// Per-dimension tile size candidates. Each dimension has an independent set -/// of candidate tile sizes. +/// of candidate tile sizes. Satisfies the requirements for use as a +/// `Lattice` value type. class TileSizeCandidates { public: TileSizeCandidates() = default; explicit TileSizeCandidates(unsigned rank) : dims(rank) {} + /// Construct from a single concrete tile size (one value per dimension). 
+ static TileSizeCandidates fromSizes(ArrayRef sizes) { + TileSizeCandidates result(sizes.size()); + for (unsigned i = 0; i < sizes.size(); ++i) { + result.dims[i].insert(sizes[i]); + } + return result; + } + unsigned rank() const { return dims.size(); } bool empty() const { return dims.empty(); } const TileSizeSet &operator[](unsigned i) const { return dims[i]; } TileSizeSet &operator[](unsigned i) { return dims[i]; } - /// Merge candidates from `other` into this. Returns true if anything changed. - bool merge(const TileSizeCandidates &other) { - assert(rank() == other.rank() && "rank mismatch"); - bool changed = false; - for (unsigned i = 0; i < rank(); ++i) { - for (int64_t v : other.dims[i]) { - changed |= dims[i].insert(v).second; - } - } - return changed; - } - - /// Merge a single concrete tile size (one value per dimension). - /// Values of -1 (unknown) are skipped. If this object is uninitialized - /// (rank 0), it is initialized from the size of `concreteSizes`. - bool merge(ArrayRef concreteSizes) { + /// Merge candidates from `other` into this. Uninitialized is identity. + void merge(const TileSizeCandidates &other) { if (empty()) { - dims.resize(concreteSizes.size()); + *this = other; + return; + } + if (other.empty()) { + return; } - assert(rank() == concreteSizes.size() && "rank mismatch"); - bool changed = false; + assert(rank() == other.rank() && "rank mismatch"); for (unsigned i = 0; i < rank(); ++i) { - if (concreteSizes[i] != -1) { - changed |= dims[i].insert(concreteSizes[i]).second; - } + dims[i].insert_range(other.dims[i]); } - return changed; } /// Returns true if any dimension has more than one candidate. @@ -107,6 +159,34 @@ class TileSizeCandidates { return result; } + /// Lattice join: per-dimension set union. Uninitialized is identity. 
+ static TileSizeCandidates join(const TileSizeCandidates &lhs, + const TileSizeCandidates &rhs) { + TileSizeCandidates result = lhs; + result.merge(rhs); + return result; + } + + /// Lattice meet: same as join (both directions accumulate via set union). + static TileSizeCandidates meet(const TileSizeCandidates &lhs, + const TileSizeCandidates &rhs) { + return join(lhs, rhs); + } + + bool operator==(const TileSizeCandidates &rhs) const { + return dims == rhs.dims; + } + + void print(raw_ostream &os) const { + os << "["; + llvm::interleaveComma(dims, os, [&](const TileSizeSet &s) { + os << "{"; + llvm::interleaveComma(s, os); + os << "}"; + }); + os << "]"; + } + private: SmallVector dims; }; @@ -143,273 +223,246 @@ static bool isDuplicatable(Value val) { return false; } -struct TileSizeState { - void propagateForward(Value val); - void propagateBackward(Value val); +//===----------------------------------------------------------------------===// +// Lattice and analysis definitions +//===----------------------------------------------------------------------===// - /// Merge candidates into a value and enqueue if anything changed. - void mergeAndEnqueue(Value val, const TileSizeCandidates &candidates) { - if (!isa(val.getType())) { - return; - } - if (candidates.empty()) { - return; - } - // If val is not yet in the map, inserting it may rehash the DenseMap - // and invalidate `candidates` if it aliases an existing entry. Copy - // directly into the new entry to avoid the dangling reference. - if (!tileSizes.count(val)) { - tileSizes[val] = candidates; - } else { - if (!tileSizes[val].merge(candidates)) { - return; - } - } - // We don't forward multiple alternatives from operations that are easy to - // duplicate. CSE will deduplicate DPS init operands, creating edges between - // unrelated compute operations. Propagating different vector tile sizes via - // shared DPS inits doesn't provide any value in that case. 
- if (isDuplicatable(val) && tileSizes[val].hasAlternatives()) { - return; - } - // Propagate the update. - forward.push(val); - backward.push(val); - } +class TileSizeLattice : public dataflow::Lattice { +public: + using Lattice::Lattice; +}; - /// Convenience: merge a single concrete tile size and enqueue if changed. - void mergeAndEnqueue(Value val, ArrayRef concreteSizes) { - TileSizeCandidates candidates(concreteSizes.size()); - candidates.merge(concreteSizes); - mergeAndEnqueue(val, candidates); +/// Read the TileSizeCandidates from a lattice, returning empty candidates +/// if the lattice value is duplicatable with alternatives. +static const TileSizeCandidates & +getCandidatesFor(Value val, const TileSizeLattice *lattice) { + static const TileSizeCandidates empty; + if (!lattice) { + return empty; } - - bool hasTileSize(Value val) const { return tileSizes.count(val); } - - const TileSizeCandidates &getCandidates(Value val) const { - static const TileSizeCandidates empty; - auto it = tileSizes.find(val); - if (it == tileSizes.end()) { - return empty; - } - return it->second; + auto &candidates = lattice->getValue(); + if (candidates.empty()) { + return empty; } - - /// Propagate through a linalg.generic: given known tile sizes on some - /// operands, infer tile sizes for other operands via indexing maps. - void propagateGenericOp(linalg::GenericOp genericOp); - - DenseMap tileSizes; - std::queue forward; - std::queue backward; -}; - -/// Collect per-dimension tile size candidate sets from a linalg op's operands. -/// Returns a TileSizeCandidates of size numLoops, where each dimension is the -/// union of all candidate tile sizes for that iteration dimension across all -/// operands. 
-static TileSizeCandidates -getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, - const TileSizeState &state) { - unsigned numLoops = linalgOp.getNumLoops(); - TileSizeCandidates result(numLoops); - for (OpOperand &operand : linalgOp->getOpOperands()) { - auto &candidates = state.getCandidates(operand.get()); - if (candidates.empty()) { - continue; - } - AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - auto mapped = candidates.mapToIterationSpace(map, numLoops); - result.merge(mapped); + if (isDuplicatable(val) && candidates.hasAlternatives()) { + return empty; } - return result; + return candidates; } -void TileSizeState::propagateGenericOp(linalg::GenericOp genericOp) { - auto perDimSizes = getIterationSpaceTileSizes(genericOp, *this); - - // Map per-dimension iteration-space candidates to each operand's dimensions - // via its indexing map. - for (OpOperand &operand : genericOp->getOpOperands()) { - AffineMap map = genericOp.getMatchingIndexingMap(&operand); - auto operandCandidates = perDimSizes.mapFromIterationSpace(map); - if (operandCandidates.empty()) { - continue; - } - mergeAndEnqueue(operand.get(), operandCandidates); +/// Forward analysis: propagates tile size candidates from operands to results. +/// Control flow through scf.for/scf.if is handled automatically by the +/// framework via RegionBranchOpInterface. +class TileSizeForwardAnalysis + : public dataflow::SparseForwardDataFlowAnalysis { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + LogicalResult initialize(Operation *top) override { + // Seed to_layout anchors before the regular initialization. This ensures + // seeds are set even for to_layout ops in regions that DeadCodeAnalysis + // hasn't yet marked as live during init. 
+ top->walk([&](ToLayoutOp toLayout) { + LDBG() << "Anchor: " << toLayout; + auto candidates = TileSizeCandidates::fromSizes( + toLayout.getLayout().getUndistributedShape()); + auto *lattice = getLatticeElement(toLayout.getResult()); + propagateIfChanged(lattice, lattice->join(candidates)); + }); + return SparseForwardDataFlowAnalysis::initialize(top); } - // Propagate to results via their corresponding init operands. - for (auto [init, result] : - llvm::zip_equal(genericOp.getDpsInits(), genericOp.getResults())) { - mergeAndEnqueue(result, getCandidates(init)); + void setToEntryState(TileSizeLattice *lattice) override { + // Entry state is uninitialized (identity for join). + propagateIfChanged(lattice, lattice->join(TileSizeCandidates())); } -} -void TileSizeState::propagateForward(Value val) { - auto &candidates = getCandidates(val); - if (candidates.empty()) { - return; - } - LDBG() << "Propagating tile size forward for: " << val; - - for (OpOperand &use : val.getUses()) { - Operation *user = use.getOwner(); - unsigned operandIdx = use.getOperandNumber(); - - // scf.for: propagate to tied loop body arg and result. - if (auto forOp = dyn_cast(user)) { - Value arg = forOp.getTiedLoopRegionIterArg(&use); - Value result = forOp.getTiedLoopResult(&use); - mergeAndEnqueue(arg, candidates); - mergeAndEnqueue(result, candidates); - continue; + LogicalResult visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override { + // to_layout: don't propagate operand forward (anchor boundary). + // Seeding is done in initialize(). + if (isa(op)) { + return success(); } - // scf.yield: propagate to parent op's results/args. 
- if (auto yieldOp = dyn_cast(user)) { - Operation *parentOp = yieldOp->getParentOp(); - if (auto forOp = dyn_cast(parentOp)) { - Value arg = forOp.getRegionIterArg(operandIdx); - Value result = forOp->getResult(operandIdx); - mergeAndEnqueue(arg, candidates); - mergeAndEnqueue(result, candidates); - continue; + // linalg.generic: propagate through indexing maps. + if (auto genericOp = dyn_cast(op)) { + unsigned numLoops = genericOp.getNumLoops(); + // Combine the information from all operands into a single candidate in + // iteration space. + TileSizeCandidates iterCandidates(numLoops); + for (OpOperand &operand : genericOp->getOpOperands()) { + auto &candidates = getCandidatesFor( + operand.get(), operands[operand.getOperandNumber()]); + if (candidates.empty()) { + continue; + } + AffineMap map = genericOp.getMatchingIndexingMap(&operand); + iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); } - if (auto ifOp = dyn_cast(parentOp)) { - Value result = ifOp->getResult(operandIdx); - mergeAndEnqueue(result, candidates); - continue; + // For each result, map the combined candidate in iteration space back to + // the result (DPS init operand) space. + for (unsigned i = 0; i < genericOp.getNumDpsInits(); ++i) { + OpOperand *init = genericOp.getDpsInitOperand(i); + AffineMap map = genericOp.getMatchingIndexingMap(init); + auto resultCandidates = iterCandidates.mapFromIterationSpace(map); + if (!resultCandidates.empty()) { + propagateIfChanged(results[i], results[i]->join(resultCandidates)); + } } + return success(); } // Elementwise ops: propagate to all results. 
- if (OpTrait::hasElementwiseMappableTraits(user)) { - for (OpResult result : user->getOpResults()) { - mergeAndEnqueue(result, candidates); + if (OpTrait::hasElementwiseMappableTraits(op)) { + TileSizeCandidates combined; + for (auto [operandLattice, operandVal] : + llvm::zip(operands, op->getOperands())) { + combined.merge(getCandidatesFor(operandVal, operandLattice)); } - continue; + for (TileSizeLattice *result : results) { + propagateIfChanged(result, result->join(combined)); + } + return success(); } - // linalg.generic: propagate through indexing maps. - if (auto genericOp = dyn_cast(user)) { - propagateGenericOp(genericOp); - continue; - } + return success(); } -} +}; -void TileSizeState::propagateBackward(Value val) { - LDBG() << "Propagating tile size backward for: " << val; - auto &candidates = getCandidates(val); - if (candidates.empty()) { - return; +/// Backward analysis: propagates tile size candidates from results to operands. +/// Control flow through scf.for/scf.if is handled automatically by the +/// framework via RegionBranchOpInterface. +class TileSizeBackwardAnalysis + : public dataflow::SparseBackwardDataFlowAnalysis { +public: + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + + void setToExitState(TileSizeLattice *lattice) override { + // Exit state is uninitialized (identity for meet). } - // Block arguments (e.g., scf.for iter_args). - if (auto blockArg = dyn_cast(val)) { - Operation *parent = val.getParentBlock()->getParentOp(); - if (auto forOp = dyn_cast(parent)) { - OpOperand *yielded = forOp.getTiedLoopYieldedValue(blockArg); - OpOperand *init = forOp.getTiedLoopInit(blockArg); - if (yielded) { - mergeAndEnqueue(yielded->get(), candidates); - } - if (init) { - mergeAndEnqueue(init->get(), candidates); + LogicalResult + visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override { + // to_layout: propagate result tile sizes backward to input. 
+ if (auto toLayout = dyn_cast(op)) { + auto &candidates = getCandidatesFor(toLayout.getResult(), results[0]); + if (!candidates.empty()) { + TileSizeLattice *inputLattice = operands[0]; + propagateIfChanged(inputLattice, inputLattice->meet(candidates)); } + return success(); } - return; - } - Operation *defOp = val.getDefiningOp(); - if (!defOp) { - return; - } - - // Elementwise ops: propagate to all operands. - if (OpTrait::hasElementwiseMappableTraits(defOp)) { - for (OpOperand &operand : defOp->getOpOperands()) { - if (isa(operand.get().getType())) { - mergeAndEnqueue(operand.get(), candidates); + // linalg.generic: propagate through indexing maps. + if (auto genericOp = dyn_cast(op)) { + unsigned numLoops = genericOp.getNumLoops(); + TileSizeCandidates iterCandidates(numLoops); + // Gather result candidates into iteration space via DPS init maps. + for (auto [result, resultLattice] : + llvm::zip(genericOp.getResults(), results)) { + auto &candidates = getCandidatesFor(result, resultLattice); + if (candidates.empty()) { + continue; + } + unsigned resultIdx = cast(result).getResultNumber(); + OpOperand *init = genericOp.getDpsInitOperand(resultIdx); + AffineMap map = genericOp.getMatchingIndexingMap(init); + iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); + } + // Gather operand candidates into iteration space. + for (OpOperand &operand : genericOp->getOpOperands()) { + auto &candidates = getCandidatesFor( + operand.get(), operands[operand.getOperandNumber()]); + if (candidates.empty()) { + continue; + } + AffineMap map = genericOp.getMatchingIndexingMap(&operand); + iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); } + // Map iteration space candidates back to each operand. 
+ for (OpOperand &operand : genericOp->getOpOperands()) { + AffineMap map = genericOp.getMatchingIndexingMap(&operand); + auto operandCandidates = iterCandidates.mapFromIterationSpace(map); + if (operandCandidates.empty()) { + continue; + } + TileSizeLattice *operandLattice = operands[operand.getOperandNumber()]; + propagateIfChanged(operandLattice, + operandLattice->meet(operandCandidates)); + } + return success(); } - return; - } - // linalg.generic: propagate through indexing maps. - if (auto genericOp = dyn_cast(defOp)) { - unsigned resultIdx = cast(val).getResultNumber(); - mergeAndEnqueue(genericOp.getDpsInitOperand(resultIdx)->get(), candidates); - propagateGenericOp(genericOp); - return; - } + // Elementwise ops: propagate to all operands. + if (OpTrait::hasElementwiseMappableTraits(op)) { + TileSizeCandidates combined; + for (auto [resultVal, resultLattice] : + llvm::zip(op->getResults(), results)) { + combined.merge(getCandidatesFor(resultVal, resultLattice)); + } + for (auto [operandLattice, operandVal] : + llvm::zip(operands, op->getOperands())) { + if (!isa(operandVal.getType())) { + continue; + } + propagateIfChanged(operandLattice, operandLattice->meet(combined)); + } + return success(); + } - // to_layout: propagate to input. - // We only propagate backward for to_layout, not forward, as to_layout is an - // anchor for initialization itself. - if (auto toLayout = dyn_cast(defOp)) { - mergeAndEnqueue(toLayout.getInput(), candidates); - return; + return success(); } - // scf.for results: propagate to yield and init. - if (auto forOp = dyn_cast(defOp)) { - unsigned resultIdx = cast(val).getResultNumber(); - Value init = forOp.getInits()[resultIdx]; - auto yieldOp = cast(forOp.getBody()->getTerminator()); - mergeAndEnqueue(init, candidates); - mergeAndEnqueue(yieldOp.getOperand(resultIdx), candidates); - return; - } + // Required by the base class. 
Non-forwarded branch operands (e.g., loop + // bounds, conditions) are scalars irrelevant to tile size propagation. + // Forwarded values (iter_args, yields) are handled by the framework via + // RegionBranchOpInterface. + void visitBranchOperand(OpOperand &operand) override {} + void visitCallOperand(OpOperand &operand) override {} + void + visitNonControlFlowArguments(RegionSuccessor &successor, + ArrayRef arguments) override {} +}; - // scf.if results: propagate to yields in both regions. - if (auto ifOp = dyn_cast(defOp)) { - unsigned resultIdx = cast(val).getResultNumber(); - auto thenYield = cast(ifOp.thenBlock()->getTerminator()); - mergeAndEnqueue(thenYield.getOperand(resultIdx), candidates); - assert(ifOp.elseBlock() && "scf.if with results must have an else block"); - auto elseYield = cast(ifOp.elseBlock()->getTerminator()); - mergeAndEnqueue(elseYield.getOperand(resultIdx), candidates); - return; - } -} +//===----------------------------------------------------------------------===// +// Result querying +//===----------------------------------------------------------------------===// -/// Run the VectorTileSizeAnalysis on the given root operation. -static void runAnalysis(Operation *root, TileSizeState &state) { - // Initialize from to_layout anchors. - root->walk([&](ToLayoutOp toLayout) { - SmallVector undistShape = - toLayout.getLayout().getUndistributedShape(); - LDBG() << "Anchor: " << toLayout; - state.mergeAndEnqueue(toLayout.getResult(), undistShape); - }); - - // Fixpoint iteration: forward first, then backward. - while (!state.forward.empty() || !state.backward.empty()) { - if (!state.forward.empty()) { - Value val = state.forward.front(); - state.forward.pop(); - state.propagateForward(val); - } else { - Value val = state.backward.front(); - state.backward.pop(); - state.propagateBackward(val); +/// Gather tile size candidates into the iteration space of a linalg op by +/// looking up each operand's lattice state in the solver. 
+static TileSizeCandidates +getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, + const DataFlowSolver &solver) { + unsigned numLoops = linalgOp.getNumLoops(); + TileSizeCandidates iterCandidates(numLoops); + for (OpOperand &operand : linalgOp->getOpOperands()) { + Value val = operand.get(); + auto *lattice = solver.lookupState(val); + auto &candidates = getCandidatesFor(val, lattice); + if (candidates.empty()) { + continue; } + AffineMap map = linalgOp.getMatchingIndexingMap(&operand); + iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); } + return iterCandidates; } -/// Given a linalg op and the analysis state, compute per-dimension sets of +/// Given a linalg op and the solver, compute per-dimension sets of /// candidate tile sizes. Returns a vector of size numLoops, where each entry /// is the deduplicated set of tile sizes for that iteration dimension. /// Returns an empty vector if any dimension has no candidates. static SmallVector> -getPerDimTileSizes(linalg::LinalgOp linalgOp, const TileSizeState &state) { - auto perDimSizes = getIterationSpaceTileSizes(linalgOp, state); +getPerDimTileSizes(linalg::LinalgOp linalgOp, const DataFlowSolver &solver) { + auto perDimSizes = getIterationSpaceTileSizes(linalgOp, solver); // Return empty if any dimension has no candidates. 
+ unsigned numLoops = linalgOp.getNumLoops(); SmallVector> results; - for (unsigned i = 0; i < perDimSizes.rank(); ++i) { + for (unsigned i = 0; i < numLoops; ++i) { if (perDimSizes[i].empty()) { return {}; } @@ -435,11 +488,18 @@ class MaterializeVectorTileSizesPass final void runOnOperation() override { auto funcOp = getOperation(); - TileSizeState state; - runAnalysis(funcOp, state); + DataFlowSolver solver; + dataflow::loadBaselineAnalyses(solver); + solver.load(); + SymbolTableCollection symbolTable; + solver.load(symbolTable); + + if (failed(solver.initializeAndRun(funcOp))) { + return signalPassFailure(); + } funcOp->walk([&](linalg::LinalgOp linalgOp) { - auto perDimSizes = getPerDimTileSizes(linalgOp, state); + auto perDimSizes = getPerDimTileSizes(linalgOp, solver); if (perDimSizes.empty()) { return; } From 68726e64111b1bc63a8f2d1477fb79c7c902b6e5 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Mon, 16 Mar 2026 14:54:53 +0000 Subject: [PATCH 3/8] Address PR feedback Signed-off-by: Lukas Sommer --- .../Codegen/Common/VectorTileSizeAnalysis.cpp | 55 +++++++------------ .../test/materialize_vector_tile_sizes.mlir | 47 ++++++++++++++++ 2 files changed, 67 insertions(+), 35 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp index 098cfe602a5e..51b6631e60ff 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp @@ -21,10 +21,10 @@ #define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis" // The purpose of this analysis is to propagate information about the -// undistributed vector tile size across the operation graph. The vector tile -// size is important information for the vectorization of operations. 
-// For example, the vector tile size can be used by GenericVectorization to -// introduce the necessary masking in the presence of padding/masking. +// vector tile size across the operation graph. The vector tile size is +// important information for the vectorization of operations. For example, the +// vector tile size can be used by GenericVectorization to introduce the +// necessary masking in the presence of padding/masking. // // The analysis is a bi-directional dataflow analysis building on top of the // upstream MLIR dataflow analysis framework. To implement the bi-directional @@ -43,7 +43,8 @@ // As the set union can not result in a conflict, no lattice state for top // (overdefined) is required in this lattice. // -// The lattice is initialized from `to_layout` operations. +// The lattice is initialized from anchor operations that provide information +// about vector tile size (e.g., `to_layout`). // // Forward propagation and backward propagation work similarly: // - For elementwise operations, candidates from the different operands @@ -204,19 +205,14 @@ static bool isDuplicatable(Value val) { if (defOp->hasTrait()) { return true; } - // Catches linalg.fill that has been lowered/fused into linalg.generic form - // (scalar input broadcast into tensor.empty output). - if (auto genericOp = dyn_cast(defOp)) { - if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 && - !isa(genericOp.getDpsInputs()[0].getType())) { - Value init = genericOp.getDpsInits()[0]; - if (init.getDefiningOp()) { - return true; - } - } - } - if (auto fillOp = dyn_cast(defOp)) { - if (fillOp.getOutputs()[0].getDefiningOp()) { + // A linalg op that doesn't read any tensor data (e.g., linalg.fill or a + // fill-like linalg.generic broadcasting a scalar) is a generator and + // duplicatable. 
+ if (auto linalgOp = dyn_cast(defOp)) { + if (llvm::none_of(linalgOp->getOpOperands(), [&](OpOperand &operand) { + return isa(operand.get().getType()) && + linalgOp.payloadUsesValueFromOperand(&operand); + })) { return true; } } @@ -258,20 +254,6 @@ class TileSizeForwardAnalysis public: using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; - LogicalResult initialize(Operation *top) override { - // Seed to_layout anchors before the regular initialization. This ensures - // seeds are set even for to_layout ops in regions that DeadCodeAnalysis - // hasn't yet marked as live during init. - top->walk([&](ToLayoutOp toLayout) { - LDBG() << "Anchor: " << toLayout; - auto candidates = TileSizeCandidates::fromSizes( - toLayout.getLayout().getUndistributedShape()); - auto *lattice = getLatticeElement(toLayout.getResult()); - propagateIfChanged(lattice, lattice->join(candidates)); - }); - return SparseForwardDataFlowAnalysis::initialize(top); - } - void setToEntryState(TileSizeLattice *lattice) override { // Entry state is uninitialized (identity for join). propagateIfChanged(lattice, lattice->join(TileSizeCandidates())); @@ -280,9 +262,12 @@ class TileSizeForwardAnalysis LogicalResult visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override { - // to_layout: don't propagate operand forward (anchor boundary). - // Seeding is done in initialize(). - if (isa(op)) { + // to_layout: seed from layout, don't propagate operand forward. 
+ if (auto toLayout = dyn_cast(op)) { + LDBG() << "Anchor: " << toLayout; + auto candidates = TileSizeCandidates::fromSizes( + toLayout.getLayout().getUndistributedShape()); + propagateIfChanged(results[0], results[0]->join(candidates)); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir index 3e6f0f3533c3..e49639b45291 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir @@ -79,6 +79,53 @@ func.func @chain_propagation_transpose( // ----- +// Chain propagation with dynamic shapes: tile sizes propagate the same way +// regardless of whether tensor dimensions are static or dynamic. + +#layout_dyn = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], batch_tile = [1, 8], outer_tile = [1, 1], + thread_tile = [1, 1], element_tile = [8, 8], + subgroup_strides = [0, 0], thread_strides = [0, 0]> + +// CHECK-LABEL: @chain_propagation_dynamic +func.func @chain_propagation_dynamic( + %arg0: tensor, %arg1: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %a = iree_vector_ext.to_layout %arg0 to layout(#layout_dyn) : tensor + %empty_ab = tensor.empty(%d0, %d1) : tensor + // CHECK: linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + %ab = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%a, %arg1 : tensor, tensor) + outs(%empty_ab : tensor) { + ^bb0(%in0: f32, %in1: f32, %out: f32): + %add = arith.addf %in0, %in1 : f32 + linalg.yield %add : f32 + } -> tensor + %empty_c = tensor.empty(%d0, %d1) : tensor + // CHECK: 
linalg.generic + // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%ab : tensor) outs(%empty_c : tensor) { + ^bb0(%in: f32, %out: f32): + %neg = arith.negf %in : f32 + linalg.yield %neg : f32 + } -> tensor + return %result : tensor +} + +// ----- + // scf.for propagation through iter_args. // The to_layout inside the loop should propagate tile sizes to the // loop iter_args and through the scf.yield. From ae181ad358e752f74cb1489abdaa3b0d1b924462 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Wed, 18 Mar 2026 15:13:40 +0000 Subject: [PATCH 4/8] Track only one tile size per dim Signed-off-by: Lukas Sommer --- .../Codegen/Common/GenericVectorization.cpp | 23 +- .../Codegen/Common/VectorTileSizeAnalysis.cpp | 303 +++++++++--------- .../test/materialize_vector_tile_sizes.mlir | 58 +++- 3 files changed, 204 insertions(+), 180 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp index be265a29db97..d3cd248b4d1e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp @@ -83,25 +83,12 @@ getVectorSizes(Operation *op, bool useConfiguredVectorSizes) { } // Try to get vector sizes from materialized tile size attribute. - // The attribute is an array of per-dimension candidate lists; use the - // maximum from each dimension. 
if (auto tileSizesAttr = - op->getAttrOfType(kVectorTileSizesAttrName)) { - SmallVector vectorSizes; - bool valid = !tileSizesAttr.empty(); - for (auto dimAttr : tileSizesAttr) { - auto dimSizes = cast(dimAttr); - if (dimSizes.empty()) { - valid = false; - break; - } - vectorSizes.push_back(*llvm::max_element(dimSizes.asArrayRef())); - } - if (valid) { - LDBG() << "Use vector sizes from materialized tile size attribute"; - SmallVector scalableFlags(vectorSizes.size(), false); - return std::make_pair(vectorSizes, scalableFlags); - } + op->getAttrOfType(kVectorTileSizesAttrName)) { + LDBG() << "Use vector sizes from materialized tile size attribute"; + SmallVector vectorSizes(tileSizesAttr.asArrayRef()); + SmallVector scalableFlags(vectorSizes.size(), false); + return std::make_pair(vectorSizes, scalableFlags); } // Try to infer the vector sizes from the IR. diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp index 51b6631e60ff..b41ade6f56c1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp @@ -16,8 +16,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/SymbolTable.h" -#include "llvm/ADT/SmallSet.h" - #define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis" // The purpose of this analysis is to propagate information about the @@ -33,21 +31,19 @@ // // The lattice for the dataflow analysis is shared by both analyses (forward and // backward). For each N-dimensional ShapedType SSA value, we have a lattice -// element comprising N sets, where each set contains the candidate tile sizes -// for that dimension. The bottom (uninitialized) state of the lattice is simply -// empty. The join/merge operation for two lattice elements is the per-dimension -// set-union of candidates. 
For example, in the 2D case: -// ({2, 4}, {16}) U ({8}, {32}) = ({2, 4, 8}, {16, 32}) -// -// As the sets can only grow, the join/meet operator is by definition monotonic. -// As the set union can not result in a conflict, no lattice state for top -// (overdefined) is required in this lattice. +// element comprising N dimensions, where each dimension is in one of three +// states: uninitialized (bottom), a single tile size value, or overdefined +// (top). The bottom (uninitialized) state is the identity for the merge +// operation. Merging two equal tile sizes for the same dimension is identity. +// Merging two different tile sizes for the same dimension results in the +// overdefined state. Overdefined is absorbing — once a dimension reaches +// overdefined, it stays overdefined. // // The lattice is initialized from anchor operations that provide information // about vector tile size (e.g., `to_layout`). // // Forward propagation and backward propagation work similarly: -// - For elementwise operations, candidates from the different operands +// - For elementwise operations, tile sizes from the different operands // (forward) or results (backwards) are merged. The merged lattice state is // then propagated to all results (forward) or operands (backward). // - For linalg.generic operations, all available information from operands @@ -57,12 +53,12 @@ // operand (backward) based on indexing maps and the mapped state is // propagated. // -// The only exception to this process are duplicatable operations such as -// `tensor.empty`. CSE connects otherwise unrelated compute ops by deduplicating -// their DPS init operands to a single tensor.empty (or similar). To avoid -// cross-polluting the vector tile size of unrelated operations, propagation -// from duplicatable operations is stopped if they contain multiple candidates -// tile sizes in at least one dimension. +// Duplicatable operations such as `tensor.empty`, constants, and generator +// linalg ops (e.g. 
linalg.fill) are excluded from propagation entirely. CSE +// connects otherwise unrelated compute ops by deduplicating their DPS init +// operands to a single tensor.empty (or similar). To avoid cross-polluting the +// vector tile size of unrelated operations, tile sizes from duplicatable +// operations are never propagated. // // For some other operations, no propagation rules are defined on purpose. For // example, `extract_slice` and `insert_slice` operations are natural boundaries @@ -70,41 +66,65 @@ // // After the dataflow solver reaches a fixpoint, the // MaterializeVectorTileSizesPass materializes the result as a discardable -// attribute. At this point, the result is a set of candidate vector tile sizes -// per iteration dimension. It is up to the users of the analysis how to select -// a tile size from the set of candidates. +// attribute. Only dimensions with a single defined (non-overdefined) tile size +// are materialized. Operations where any dimension is uninitialized or +// overdefined do not receive the attribute. namespace mlir::iree_compiler { using namespace IREE::VectorExt; -using TileSizeSet = llvm::SmallSet; +/// Sentinel values for per-dimension tile size state. Valid tile sizes are +/// always positive. +constexpr int64_t kUninitialized = 0; +constexpr int64_t kOverdefined = -1; + +/// Merge a single dimension's tile size value. Returns the merged result. +static int64_t mergeDim(int64_t a, int64_t b) { + if (a == kOverdefined || b == kOverdefined) { + return kOverdefined; + } + if (a == kUninitialized) { + return b; + } + if (b == kUninitialized) { + return a; + } + return a == b ? a : kOverdefined; +} -/// Per-dimension tile size candidates. Each dimension has an independent set -/// of candidate tile sizes. Satisfies the requirements for use as a -/// `Lattice` value type. -class TileSizeCandidates { +/// Per-dimension tile sizes. 
Each dimension holds a single tile size value, or +/// one of the sentinel values kUninitialized/kOverdefined. Satisfies the +/// requirements for use as a `Lattice` value type. +class TileSizes { public: - TileSizeCandidates() = default; - explicit TileSizeCandidates(unsigned rank) : dims(rank) {} + TileSizes() = default; + explicit TileSizes(unsigned rank) : dims(rank, kUninitialized) {} /// Construct from a single concrete tile size (one value per dimension). - static TileSizeCandidates fromSizes(ArrayRef sizes) { - TileSizeCandidates result(sizes.size()); - for (unsigned i = 0; i < sizes.size(); ++i) { - result.dims[i].insert(sizes[i]); - } + static TileSizes fromSizes(ArrayRef sizes) { + TileSizes result; + result.dims.assign(sizes.begin(), sizes.end()); return result; } unsigned rank() const { return dims.size(); } bool empty() const { return dims.empty(); } - const TileSizeSet &operator[](unsigned i) const { return dims[i]; } - TileSizeSet &operator[](unsigned i) { return dims[i]; } + int64_t operator[](unsigned i) const { return dims[i]; } + + /// Returns true if all dimensions have a defined (positive) tile size. + bool isDefined() const { + return !empty() && llvm::all_of(dims, [](int64_t v) { return v > 0; }); + } + + /// Returns true if any dimension is overdefined. + bool isOverdefined() const { + return llvm::any_of(dims, [](int64_t v) { return v == kOverdefined; }); + } - /// Merge candidates from `other` into this. Uninitialized is identity. - void merge(const TileSizeCandidates &other) { + /// Merge tile sizes from `other` into this. Uninitialized is identity. + void merge(const TileSizes &other) { if (empty()) { *this = other; return; @@ -114,45 +134,37 @@ class TileSizeCandidates { } assert(rank() == other.rank() && "rank mismatch"); for (unsigned i = 0; i < rank(); ++i) { - dims[i].insert_range(other.dims[i]); + dims[i] = mergeDim(dims[i], other.dims[i]); } } - /// Returns true if any dimension has more than one candidate. 
- bool hasAlternatives() const { - return llvm::any_of(dims, - [](const TileSizeSet &s) { return s.size() > 1; }); - } - /// Map from operand space to iteration space via an indexing map. - TileSizeCandidates mapToIterationSpace(AffineMap indexingMap, - unsigned numLoops) const { - TileSizeCandidates result(numLoops); + TileSizes mapToIterationSpace(AffineMap indexingMap, + unsigned numLoops) const { + TileSizes result(numLoops); for (unsigned i = 0; i < indexingMap.getNumResults(); ++i) { auto dimExpr = dyn_cast(indexingMap.getResult(i)); if (!dimExpr) { continue; } unsigned iterDim = dimExpr.getPosition(); - for (int64_t v : dims[i]) { - result.dims[iterDim].insert(v); - } + result.dims[iterDim] = mergeDim(result.dims[iterDim], dims[i]); } return result; } /// Map from iteration space to operand space via an indexing map. - /// Returns empty TileSizeCandidates if any operand dim can't be determined. - TileSizeCandidates mapFromIterationSpace(AffineMap indexingMap) const { + /// Returns empty TileSizes if any operand dim can't be determined. + TileSizes mapFromIterationSpace(AffineMap indexingMap) const { unsigned numResults = indexingMap.getNumResults(); - TileSizeCandidates result(numResults); + TileSizes result(numResults); for (unsigned i = 0; i < numResults; ++i) { auto dimExpr = dyn_cast(indexingMap.getResult(i)); if (!dimExpr) { return {}; } unsigned iterDim = dimExpr.getPosition(); - if (iterDim >= rank() || dims[iterDim].empty()) { + if (iterDim >= rank() || dims[iterDim] == kUninitialized) { return {}; } result.dims[i] = dims[iterDim]; @@ -160,40 +172,40 @@ class TileSizeCandidates { return result; } - /// Lattice join: per-dimension set union. Uninitialized is identity. - static TileSizeCandidates join(const TileSizeCandidates &lhs, - const TileSizeCandidates &rhs) { - TileSizeCandidates result = lhs; + /// Lattice join: per-dimension merge. Uninitialized is identity. 
+ static TileSizes join(const TileSizes &lhs, const TileSizes &rhs) { + TileSizes result = lhs; result.merge(rhs); return result; } - /// Lattice meet: same as join (both directions accumulate via set union). - static TileSizeCandidates meet(const TileSizeCandidates &lhs, - const TileSizeCandidates &rhs) { + /// Lattice meet: same as join (both directions merge the same way). + static TileSizes meet(const TileSizes &lhs, const TileSizes &rhs) { return join(lhs, rhs); } - bool operator==(const TileSizeCandidates &rhs) const { - return dims == rhs.dims; - } + bool operator==(const TileSizes &rhs) const { return dims == rhs.dims; } void print(raw_ostream &os) const { os << "["; - llvm::interleaveComma(dims, os, [&](const TileSizeSet &s) { - os << "{"; - llvm::interleaveComma(s, os); - os << "}"; + llvm::interleaveComma(dims, os, [&](int64_t v) { + if (v == kUninitialized) { + os << "?"; + } else if (v == kOverdefined) { + os << "T"; + } else { + os << v; + } }); os << "]"; } private: - SmallVector dims; + SmallVector dims; }; /// Returns true if the operation is trivially duplicatable and should not -/// propagate merged tile sizes across independent consumers. +/// propagate tile sizes across independent consumers. static bool isDuplicatable(Value val) { Operation *defOp = val.getDefiningOp(); if (!defOp) { @@ -223,30 +235,30 @@ static bool isDuplicatable(Value val) { // Lattice and analysis definitions //===----------------------------------------------------------------------===// -class TileSizeLattice : public dataflow::Lattice { +class TileSizeLattice : public dataflow::Lattice { public: using Lattice::Lattice; }; -/// Read the TileSizeCandidates from a lattice, returning empty candidates -/// if the lattice value is duplicatable with alternatives. 
-static const TileSizeCandidates & -getCandidatesFor(Value val, const TileSizeLattice *lattice) { - static const TileSizeCandidates empty; +/// Read the TileSizes from a lattice, returning empty tile sizes if the lattice +/// value is from a duplicatable operation. +static const TileSizes &getTileSizesFor(Value val, + const TileSizeLattice *lattice) { + static const TileSizes empty; if (!lattice) { return empty; } - auto &candidates = lattice->getValue(); - if (candidates.empty()) { + auto &tileSizes = lattice->getValue(); + if (tileSizes.empty()) { return empty; } - if (isDuplicatable(val) && candidates.hasAlternatives()) { + if (isDuplicatable(val)) { return empty; } - return candidates; + return tileSizes; } -/// Forward analysis: propagates tile size candidates from operands to results. +/// Forward analysis: propagates tile sizes from operands to results. /// Control flow through scf.for/scf.if is handled automatically by the /// framework via RegionBranchOpInterface. class TileSizeForwardAnalysis @@ -256,7 +268,7 @@ class TileSizeForwardAnalysis void setToEntryState(TileSizeLattice *lattice) override { // Entry state is uninitialized (identity for join). - propagateIfChanged(lattice, lattice->join(TileSizeCandidates())); + propagateIfChanged(lattice, lattice->join(TileSizes())); } LogicalResult visitOperation(Operation *op, @@ -265,35 +277,31 @@ class TileSizeForwardAnalysis // to_layout: seed from layout, don't propagate operand forward. if (auto toLayout = dyn_cast(op)) { LDBG() << "Anchor: " << toLayout; - auto candidates = TileSizeCandidates::fromSizes( - toLayout.getLayout().getUndistributedShape()); - propagateIfChanged(results[0], results[0]->join(candidates)); + auto tileSizes = + TileSizes::fromSizes(toLayout.getLayout().getUndistributedShape()); + propagateIfChanged(results[0], results[0]->join(tileSizes)); return success(); } // linalg.generic: propagate through indexing maps. 
if (auto genericOp = dyn_cast(op)) { unsigned numLoops = genericOp.getNumLoops(); - // Combine the information from all operands into a single candidate in - // iteration space. - TileSizeCandidates iterCandidates(numLoops); + TileSizes iterTileSizes(numLoops); for (OpOperand &operand : genericOp->getOpOperands()) { - auto &candidates = getCandidatesFor( - operand.get(), operands[operand.getOperandNumber()]); - if (candidates.empty()) { + auto &ts = getTileSizesFor(operand.get(), + operands[operand.getOperandNumber()]); + if (ts.empty()) { continue; } AffineMap map = genericOp.getMatchingIndexingMap(&operand); - iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); + iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); } - // For each result, map the combined candidate in iteration space back to - // the result (DPS init operand) space. for (unsigned i = 0; i < genericOp.getNumDpsInits(); ++i) { OpOperand *init = genericOp.getDpsInitOperand(i); AffineMap map = genericOp.getMatchingIndexingMap(init); - auto resultCandidates = iterCandidates.mapFromIterationSpace(map); - if (!resultCandidates.empty()) { - propagateIfChanged(results[i], results[i]->join(resultCandidates)); + auto resultTileSizes = iterTileSizes.mapFromIterationSpace(map); + if (!resultTileSizes.empty()) { + propagateIfChanged(results[i], results[i]->join(resultTileSizes)); } } return success(); @@ -301,10 +309,10 @@ class TileSizeForwardAnalysis // Elementwise ops: propagate to all results. 
if (OpTrait::hasElementwiseMappableTraits(op)) { - TileSizeCandidates combined; + TileSizes combined; for (auto [operandLattice, operandVal] : llvm::zip(operands, op->getOperands())) { - combined.merge(getCandidatesFor(operandVal, operandLattice)); + combined.merge(getTileSizesFor(operandVal, operandLattice)); } for (TileSizeLattice *result : results) { propagateIfChanged(result, result->join(combined)); @@ -316,7 +324,7 @@ class TileSizeForwardAnalysis } }; -/// Backward analysis: propagates tile size candidates from results to operands. +/// Backward analysis: propagates tile sizes from results to operands. /// Control flow through scf.for/scf.if is handled automatically by the /// framework via RegionBranchOpInterface. class TileSizeBackwardAnalysis @@ -333,10 +341,10 @@ class TileSizeBackwardAnalysis ArrayRef results) override { // to_layout: propagate result tile sizes backward to input. if (auto toLayout = dyn_cast(op)) { - auto &candidates = getCandidatesFor(toLayout.getResult(), results[0]); - if (!candidates.empty()) { + auto &ts = getTileSizesFor(toLayout.getResult(), results[0]); + if (!ts.empty()) { TileSizeLattice *inputLattice = operands[0]; - propagateIfChanged(inputLattice, inputLattice->meet(candidates)); + propagateIfChanged(inputLattice, inputLattice->meet(ts)); } return success(); } @@ -344,49 +352,49 @@ class TileSizeBackwardAnalysis // linalg.generic: propagate through indexing maps. if (auto genericOp = dyn_cast(op)) { unsigned numLoops = genericOp.getNumLoops(); - TileSizeCandidates iterCandidates(numLoops); - // Gather result candidates into iteration space via DPS init maps. + TileSizes iterTileSizes(numLoops); + // Gather result tile sizes into iteration space via DPS init maps. 
for (auto [result, resultLattice] : llvm::zip(genericOp.getResults(), results)) { - auto &candidates = getCandidatesFor(result, resultLattice); - if (candidates.empty()) { + auto &ts = getTileSizesFor(result, resultLattice); + if (ts.empty()) { continue; } unsigned resultIdx = cast(result).getResultNumber(); OpOperand *init = genericOp.getDpsInitOperand(resultIdx); AffineMap map = genericOp.getMatchingIndexingMap(init); - iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); + iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); } - // Gather operand candidates into iteration space. + // Gather operand tile sizes into iteration space. for (OpOperand &operand : genericOp->getOpOperands()) { - auto &candidates = getCandidatesFor( - operand.get(), operands[operand.getOperandNumber()]); - if (candidates.empty()) { + auto &ts = getTileSizesFor(operand.get(), + operands[operand.getOperandNumber()]); + if (ts.empty()) { continue; } AffineMap map = genericOp.getMatchingIndexingMap(&operand); - iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); + iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); } - // Map iteration space candidates back to each operand. + // Map iteration space tile sizes back to each operand. for (OpOperand &operand : genericOp->getOpOperands()) { AffineMap map = genericOp.getMatchingIndexingMap(&operand); - auto operandCandidates = iterCandidates.mapFromIterationSpace(map); - if (operandCandidates.empty()) { + auto operandTileSizes = iterTileSizes.mapFromIterationSpace(map); + if (operandTileSizes.empty()) { continue; } TileSizeLattice *operandLattice = operands[operand.getOperandNumber()]; propagateIfChanged(operandLattice, - operandLattice->meet(operandCandidates)); + operandLattice->meet(operandTileSizes)); } return success(); } // Elementwise ops: propagate to all operands. 
if (OpTrait::hasElementwiseMappableTraits(op)) { - TileSizeCandidates combined; + TileSizes combined; for (auto [resultVal, resultLattice] : llvm::zip(op->getResults(), results)) { - combined.merge(getCandidatesFor(resultVal, resultLattice)); + combined.merge(getTileSizesFor(resultVal, resultLattice)); } for (auto [operandLattice, operandVal] : llvm::zip(operands, op->getOperands())) { @@ -416,43 +424,39 @@ class TileSizeBackwardAnalysis // Result querying //===----------------------------------------------------------------------===// -/// Gather tile size candidates into the iteration space of a linalg op by -/// looking up each operand's lattice state in the solver. -static TileSizeCandidates -getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, - const DataFlowSolver &solver) { +/// Gather tile sizes into the iteration space of a linalg op by looking up each +/// operand's lattice state in the solver. +static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, + const DataFlowSolver &solver) { unsigned numLoops = linalgOp.getNumLoops(); - TileSizeCandidates iterCandidates(numLoops); + TileSizes iterTileSizes(numLoops); for (OpOperand &operand : linalgOp->getOpOperands()) { Value val = operand.get(); auto *lattice = solver.lookupState(val); - auto &candidates = getCandidatesFor(val, lattice); - if (candidates.empty()) { + auto &ts = getTileSizesFor(val, lattice); + if (ts.empty()) { continue; } AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - iterCandidates.merge(candidates.mapToIterationSpace(map, numLoops)); + iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); } - return iterCandidates; + return iterTileSizes; } -/// Given a linalg op and the solver, compute per-dimension sets of -/// candidate tile sizes. Returns a vector of size numLoops, where each entry -/// is the deduplicated set of tile sizes for that iteration dimension. -/// Returns an empty vector if any dimension has no candidates. 
-static SmallVector> +/// Given a linalg op and the solver, compute per-dimension tile sizes. +/// Returns a vector of one tile size per iteration dimension, or nullopt if +/// any dimension is uninitialized or overdefined. +static std::optional> getPerDimTileSizes(linalg::LinalgOp linalgOp, const DataFlowSolver &solver) { - auto perDimSizes = getIterationSpaceTileSizes(linalgOp, solver); - - // Return empty if any dimension has no candidates. + auto tileSizes = getIterationSpaceTileSizes(linalgOp, solver); + if (!tileSizes.isDefined()) { + return std::nullopt; + } unsigned numLoops = linalgOp.getNumLoops(); - SmallVector> results; + + SmallVector results; for (unsigned i = 0; i < numLoops; ++i) { - if (perDimSizes[i].empty()) { - return {}; - } - results.push_back( - SmallVector(perDimSizes[i].begin(), perDimSizes[i].end())); + results.push_back(tileSizes[i]); } return results; } @@ -485,19 +489,14 @@ class MaterializeVectorTileSizesPass final funcOp->walk([&](linalg::LinalgOp linalgOp) { auto perDimSizes = getPerDimTileSizes(linalgOp, solver); - if (perDimSizes.empty()) { + if (!perDimSizes) { return; } LDBG() << "Materializing tile size on " << *linalgOp; - - SmallVector dimAttrs; - for (const auto &dimSizes : perDimSizes) { - dimAttrs.push_back( - DenseI64ArrayAttr::get(linalgOp->getContext(), dimSizes)); - } - linalgOp->setAttr(kVectorTileSizesAttrName, - ArrayAttr::get(linalgOp->getContext(), dimAttrs)); + linalgOp->setAttr( + kVectorTileSizesAttrName, + DenseI64ArrayAttr::get(linalgOp->getContext(), *perDimSizes)); }); } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir index e49639b45291..3efe010d150a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir @@ -11,7 +11,7 @@ func.func 
@elementwise_from_anchor(%arg0: tensor<63xf16>) -> tensor<63xf16> { %empty = tensor.empty() : tensor<63xf16> // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] @@ -22,7 +22,7 @@ func.func @elementwise_from_anchor(%arg0: tensor<63xf16>) -> tensor<63xf16> { } -> tensor<63xf16> %1 = iree_vector_ext.to_layout %0 to layout(#layout) : tensor<63xf16> // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] @@ -50,7 +50,7 @@ func.func @chain_propagation_transpose( %a = iree_vector_ext.to_layout %arg0 to layout(#layout_2d) : tensor<8x64xf32> %empty_ab = tensor.empty() : tensor<8x64xf32> // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %ab = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, @@ -64,7 +64,7 @@ func.func @chain_propagation_transpose( } -> tensor<8x64xf32> %empty_t = tensor.empty() : tensor<64x8xf32> // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -97,7 +97,7 @@ func.func @chain_propagation_dynamic( %a = iree_vector_ext.to_layout %arg0 to layout(#layout_dyn) : tensor %empty_ab = tensor.empty(%d0, %d1) : tensor // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %ab = linalg.generic { indexing_maps = 
[affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, @@ -111,7 +111,7 @@ func.func @chain_propagation_dynamic( } -> tensor %empty_c = tensor.empty(%d0, %d1) : tensor // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -129,6 +129,8 @@ func.func @chain_propagation_dynamic( // scf.for propagation through iter_args. // The to_layout inside the loop should propagate tile sizes to the // loop iter_args and through the scf.yield. +// The fill-like %init generic is duplicatable, so it should NOT receive +// tile sizes. #layout = #iree_vector_ext.nested_layout< subgroup_tile = [8], batch_tile = [1], outer_tile = [1], @@ -140,7 +142,7 @@ func.func @scf_for_propagation(%arg0: tensor<512xf32>, %lb: index, %ub: index, % %empty = tensor.empty() : tensor<512xf32> %cst = arith.constant 0.0 : f32 // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + // CHECK-NOT: iree_codegen.vector_tile_sizes %init = linalg.generic { indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] @@ -151,7 +153,7 @@ func.func @scf_for_propagation(%arg0: tensor<512xf32>, %lb: index, %ub: index, % %result = scf.for %iv = %lb to %ub step %step iter_args(%iter = %init) -> tensor<512xf32> { %laid_out = iree_vector_ext.to_layout %iter to layout(#layout) : tensor<512xf32> // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %updated = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] @@ -189,7 +191,7 @@ func.func @contraction_indexing_maps( %bl = iree_vector_ext.to_layout %b to layout(#layout_b) : tensor<512x63xf16> %cl = 
iree_vector_ext.to_layout %c to layout(#layout_c) : tensor<512xf32> // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array, array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %result = linalg.generic { indexing_maps = [ affine_map<(d0, d1) -> (d0)>, @@ -230,7 +232,7 @@ func.func @scf_if_propagation(%arg0: tensor<512xf32>, %cond: i1) -> tensor<512xf scf.yield %fill : tensor<512xf32> } // CHECK: linalg.generic - // CHECK-SAME: iree_codegen.vector_tile_sizes = [array] + // CHECK-SAME: iree_codegen.vector_tile_sizes = array %result = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] @@ -241,3 +243,39 @@ func.func @scf_if_propagation(%arg0: tensor<512xf32>, %cond: i1) -> tensor<512xf } -> tensor<512xf32> return %result : tensor<512xf32> } + +// ----- + +// Conflicting tile sizes: two to_layout ops with different tile sizes for the +// same dimension feed into one generic. The dimension becomes overdefined, so +// no tile size attribute should be materialized. 
+ +#layout_32 = #iree_vector_ext.nested_layout< + subgroup_tile = [1], batch_tile = [4], outer_tile = [1], + thread_tile = [1], element_tile = [8], + subgroup_strides = [0], thread_strides = [0]> + +#layout_64 = #iree_vector_ext.nested_layout< + subgroup_tile = [1], batch_tile = [8], outer_tile = [1], + thread_tile = [1], element_tile = [8], + subgroup_strides = [0], thread_strides = [0]> + +// CHECK-LABEL: @conflicting_tile_sizes +func.func @conflicting_tile_sizes(%arg0: tensor<64xf16>, %arg1: tensor<64xf16>) -> tensor<64xf16> { + %a = iree_vector_ext.to_layout %arg0 to layout(#layout_32) : tensor<64xf16> + %b = iree_vector_ext.to_layout %arg1 to layout(#layout_64) : tensor<64xf16> + %empty = tensor.empty() : tensor<64xf16> + // CHECK: linalg.generic + // CHECK-NOT: iree_codegen.vector_tile_sizes + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%a, %b : tensor<64xf16>, tensor<64xf16>) outs(%empty : tensor<64xf16>) { + ^bb0(%in0: f16, %in1: f16, %out: f16): + %add = arith.addf %in0, %in1 : f16 + linalg.yield %add : f16 + } -> tensor<64xf16> + return %result : tensor<64xf16> +} From 939c8406483d2667368cb3f47340b1f0db61e852 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Thu, 19 Mar 2026 10:10:31 +0000 Subject: [PATCH 5/8] Address PR feedback Signed-off-by: Lukas Sommer --- .../iree/compiler/Codegen/Common/BUILD.bazel | 2 +- .../compiler/Codegen/Common/CMakeLists.txt | 2 +- .../Common/MaterializeVectorTileSizes.cpp | 469 ++++++++++++++++++ .../Codegen/Common/VectorTileSizeAnalysis.cpp | 137 +++-- 4 files changed, 538 insertions(+), 72 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 6aabd195093e..7e1a5bc2813f 100644 --- 
a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -136,6 +136,7 @@ iree_compiler_cc_library( "MaterializeEncodingPatterns.cpp", "MaterializeTuningSpecs.cpp", "MaterializeVectorMasking.cpp", + "MaterializeVectorTileSizes.cpp", "MathTransform.cpp", "MemrefCopyToLinalg.cpp", "NormalizeLoopBounds.cpp", @@ -174,7 +175,6 @@ iree_compiler_cc_library( "UnrollAnnotatedLoops.cpp", "UserConfig.cpp", "VectorLayoutAnalysis.cpp", - "VectorTileSizeAnalysis.cpp", "VectorTransferLowering.cpp", "VectorizeMemrefCopy.cpp", "VerifyPipelineConstraints.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index f3ca75204682..3a043febbb2d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -129,6 +129,7 @@ iree_cc_library( "MaterializeEncodingPatterns.cpp" "MaterializeTuningSpecs.cpp" "MaterializeVectorMasking.cpp" + "MaterializeVectorTileSizes.cpp" "MathTransform.cpp" "MemrefCopyToLinalg.cpp" "NormalizeLoopBounds.cpp" @@ -167,7 +168,6 @@ iree_cc_library( "UnrollAnnotatedLoops.cpp" "UserConfig.cpp" "VectorLayoutAnalysis.cpp" - "VectorTileSizeAnalysis.cpp" "VectorTransferLowering.cpp" "VectorizeMemrefCopy.cpp" "VerifyPipelineConstraints.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp new file mode 100644 index 000000000000..2b13bf9654b7 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp @@ -0,0 +1,469 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" + +#include "llvm/Support/DebugLog.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlow/Utils.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/SymbolTable.h" + +#define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis" + +// The purpose of this analysis is to propagate information about the +// vector tile size across the operation graph. The vector tile size is +// important information for the vectorization of operations. For example, the +// vector tile size can be used by GenericVectorization to introduce the +// necessary masking in the presence of padding/masking. +// +// The analysis is a bi-directional dataflow analysis building on top of the +// upstream MLIR dataflow analysis framework. To implement the bi-directional +// propagation, it combines a sparse forward analysis and a sparse backward +// analysis in the same solver. +// +// The lattice for the dataflow analysis is shared by both analyses (forward and +// backward). For each N-dimensional ShapedType SSA value, we have a lattice +// element comprising N dimensions, where each dimension is in one of three +// states: uninitialized (bottom), a single tile size value, or overdefined +// (top). The bottom (uninitialized) state is the identity for the merge +// operation. Merging two equal tile sizes for the same dimension is identity. +// Merging two different tile sizes for the same dimension results in the +// overdefined state. Overdefined is absorbing — once a dimension reaches +// overdefined, it stays overdefined. 
+// +// The lattice is initialized from anchor operations that provide information +// about vector tile size (e.g., `to_layout`). +// +// Forward propagation and backward propagation work similarly: +// - For linalg operations, all available information from operands (forward) or +// results & operands (backward) is mapped to the iteration space based on +// indexing maps and merged into a single lattice state. That lattice state in +// the iteration space is then mapped to each result (forward) or operand +// (backward) based on indexing maps and the mapped state is propagated. +// +// Duplicatable operations such as `tensor.empty`, constants, and generator +// linalg ops (e.g. linalg.fill) are excluded from propagation entirely. CSE +// connects otherwise unrelated compute ops by deduplicating their DPS init +// operands to a single tensor.empty (or similar). To avoid cross-polluting the +// vector tile size of unrelated operations, tile sizes from duplicatable +// operations are never propagated. +// +// For some other operations, no propagation rules are defined on purpose. For +// example, `extract_slice` and `insert_slice` operations are natural boundaries +// of tiling/padding, therefore no information is propagated across them. +// +// After the dataflow solver reaches a fixpoint, the +// MaterializeVectorTileSizesPass materializes the result as a discardable +// attribute. Only dimensions with a single defined (non-overdefined) tile size +// are materialized. Operations where any dimension is uninitialized or +// overdefined do not receive the attribute. + +namespace mlir::iree_compiler { + +using namespace IREE::VectorExt; + +/// Per-dimension tile sizes. Each dimension holds a single tile size value, or +/// one of the sentinel values kUninitialized/kOverdefined. Satisfies the +/// requirements for use as a `Lattice` value type. 
+class TileSizes { +public: + TileSizes() = default; + explicit TileSizes(unsigned rank) : dims(rank, kUninitialized) {} + + /// Construct from concrete tile sizes (one value per dimension). + TileSizes(ArrayRef sizes) : dims(sizes) {} + + unsigned rank() const { return dims.size(); } + bool empty() const { return dims.empty(); } + const llvm::SmallVector &getDims() const { return dims; } + + int64_t operator[](unsigned i) const { return dims[i]; } + + /// Returns true if all dimensions have a defined (positive) tile size. + bool isDefined() const { + return !empty() && llvm::all_of(dims, [](int64_t v) { + return v != kUninitialized && v != kOverdefined; + }); + } + + /// Returns true if any dimension is overdefined. + bool isOverdefined() const { + return llvm::any_of(dims, [](int64_t v) { return v == kOverdefined; }); + } + + /// Merge tile sizes from `other` into this. Uninitialized is identity. + void merge(const TileSizes &other) { + if (empty()) { + *this = other; + return; + } + if (other.empty()) { + return; + } + assert(rank() == other.rank() && "rank mismatch"); + for (unsigned i = 0; i < rank(); ++i) { + dims[i] = mergeDim(dims[i], other.dims[i]); + } + } + + /// Map from operand space to iteration space via an indexing map. + TileSizes mapToIterationSpace(AffineMap indexingMap) const { + TileSizes result(indexingMap.getNumDims()); + for (unsigned i = 0; i < indexingMap.getNumResults(); ++i) { + auto dimExpr = dyn_cast(indexingMap.getResult(i)); + if (!dimExpr) { + continue; + } + unsigned iterDim = dimExpr.getPosition(); + result.dims[iterDim] = mergeDim(result.dims[iterDim], dims[i]); + } + return result; + } + + /// Map from iteration space to operand space via an indexing map. + /// Returns empty TileSizes if any operand dim can't be determined. 
+  TileSizes mapFromIterationSpace(AffineMap indexingMap) const {
+    unsigned numResults = indexingMap.getNumResults();
+    TileSizes result(numResults);
+    for (unsigned i = 0; i < numResults; ++i) {
+      auto dimExpr = dyn_cast<AffineDimExpr>(indexingMap.getResult(i));
+      if (!dimExpr) {
+        return {};
+      }
+      unsigned iterDim = dimExpr.getPosition();
+      if (iterDim >= rank() || dims[iterDim] == kUninitialized) {
+        return {};
+      }
+      result.dims[i] = dims[iterDim];
+    }
+    return result;
+  }
+
+  /// Lattice join: per-dimension merge. Uninitialized is identity.
+  static TileSizes join(const TileSizes &lhs, const TileSizes &rhs) {
+    TileSizes result = lhs;
+    result.merge(rhs);
+    return result;
+  }
+
+  /// Lattice meet: same as join (both directions merge the same way).
+  static TileSizes meet(const TileSizes &lhs, const TileSizes &rhs) {
+    return join(lhs, rhs);
+  }
+
+  bool operator==(const TileSizes &rhs) const { return dims == rhs.dims; }
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleaveComma(dims, os, [&](int64_t v) {
+      if (v == kUninitialized) {
+        os << "?";
+      } else if (v == kOverdefined) {
+        os << "<overdefined>";
+      } else {
+        os << v;
+      }
+    });
+    os << "]";
+  }
+
+private:
+  SmallVector<int64_t> dims;
+
+  /// Sentinel values for per-dimension tile size state. Valid tile sizes are
+  /// always positive.
+  static constexpr int64_t kUninitialized = 0;
+  static constexpr int64_t kOverdefined = -1;
+
+  /// Merge a single dimension's tile size value. Returns the merged result.
+  static int64_t mergeDim(int64_t a, int64_t b) {
+    if (a == kOverdefined || b == kOverdefined) {
+      return kOverdefined;
+    }
+    if (a == kUninitialized) {
+      return b;
+    }
+    if (b == kUninitialized) {
+      return a;
+    }
+    return a == b ? a : kOverdefined;
+  }
+};
+
+/// Returns true if the operation is trivially duplicatable and should not
+/// propagate tile sizes across independent consumers.
+static bool isDuplicatable(Value val) {
+  Operation *defOp = val.getDefiningOp();
+  if (!defOp) {
+    return false;
+  }
+  if (isa<tensor::EmptyOp>(defOp)) {
+    return true;
+  }
+  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
+    return true;
+  }
+  // A linalg op that doesn't read any tensor data (e.g., linalg.fill or a
+  // fill-like linalg.generic broadcasting a scalar) is a generator and
+  // duplicatable.
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(defOp)) {
+    if (llvm::none_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
+          return isa<ShapedType>(operand.get().getType()) &&
+                 linalgOp.payloadUsesValueFromOperand(&operand);
+        })) {
+      return true;
+    }
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Lattice and analysis definitions
+//===----------------------------------------------------------------------===//
+
+class TileSizeLattice : public dataflow::Lattice<TileSizes> {
+public:
+  using Lattice::Lattice;
+};
+
+/// Read the TileSizes from a lattice, returning empty tile sizes if the lattice
+/// value is from a duplicatable operation.
+static const TileSizes getTileSizesFor(Value val,
+                                       const TileSizeLattice *lattice) {
+  if (!lattice) {
+    return {};
+  }
+  auto &tileSizes = lattice->getValue();
+  if (tileSizes.empty()) {
+    return {};
+  }
+  if (isDuplicatable(val)) {
+    return {};
+  }
+  return tileSizes;
+}
+
+/// Forward analysis: propagates tile sizes from operands to results.
+/// Control flow through scf.for/scf.if is handled automatically by the
+/// framework via RegionBranchOpInterface.
+class TileSizeForwardAnalysis
+    : public dataflow::SparseForwardDataFlowAnalysis<TileSizeLattice> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  void setToEntryState(TileSizeLattice *lattice) override {
+    // Entry state is uninitialized (identity for join).
+    propagateIfChanged(lattice, lattice->join(TileSizes()));
+  }
+
+  LogicalResult visitOperation(Operation *op,
+                               ArrayRef<const TileSizeLattice *> operands,
+                               ArrayRef<TileSizeLattice *> results) override {
+    // to_layout: seed from layout, don't propagate operand forward.
+    if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
+      LDBG() << "Anchor: " << toLayout;
+      TileSizes tileSizes(toLayout.getLayout().getUndistributedShape());
+      propagateIfChanged(results[0], results[0]->join(tileSizes));
+      return success();
+    }
+
+    // Linalg ops: propagate through indexing maps.
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      unsigned numLoops = linalgOp.getNumLoops();
+      TileSizes iterTileSizes(numLoops);
+      for (OpOperand &operand : linalgOp->getOpOperands()) {
+        auto &ts = getTileSizesFor(operand.get(),
+                                   operands[operand.getOperandNumber()]);
+        if (ts.empty()) {
+          continue;
+        }
+        AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+        assert(map.getNumDims() == numLoops);
+        iterTileSizes.merge(ts.mapToIterationSpace(map));
+      }
+      for (unsigned i = 0; i < linalgOp.getNumDpsInits(); ++i) {
+        OpOperand *init = linalgOp.getDpsInitOperand(i);
+        AffineMap map = linalgOp.getMatchingIndexingMap(init);
+        auto resultTileSizes = iterTileSizes.mapFromIterationSpace(map);
+        if (!resultTileSizes.empty()) {
+          propagateIfChanged(results[i], results[i]->join(resultTileSizes));
+        }
+      }
+      return success();
+    }
+
+    return success();
+  }
+};
+
+/// Backward analysis: propagates tile sizes from results to operands.
+/// Control flow through scf.for/scf.if is handled automatically by the
+/// framework via RegionBranchOpInterface.
+class TileSizeBackwardAnalysis
+    : public dataflow::SparseBackwardDataFlowAnalysis<TileSizeLattice> {
+public:
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  void setToExitState(TileSizeLattice *lattice) override {
+    // Exit state is uninitialized (identity for meet).
+  }
+
+  LogicalResult
+  visitOperation(Operation *op, ArrayRef<TileSizeLattice *> operands,
+                 ArrayRef<const TileSizeLattice *> results) override {
+    // to_layout is always an anchor op; propagate tile sizes backward to the
+    // input.
+    if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
+      auto &ts = getTileSizesFor(toLayout.getResult(), results[0]);
+      if (!ts.empty()) {
+        TileSizeLattice *inputLattice = operands[0];
+        propagateIfChanged(inputLattice, inputLattice->meet(ts));
+      }
+      return success();
+    }
+
+    // Linalg ops: propagate through indexing maps.
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      unsigned numLoops = linalgOp.getNumLoops();
+      TileSizes iterTileSizes(numLoops);
+      // Gather result tile sizes into iteration space via DPS init maps.
+      for (auto [result, resultLattice] :
+           llvm::zip(linalgOp.getOperation()->getResults(), results)) {
+        auto &ts = getTileSizesFor(result, resultLattice);
+        if (ts.empty()) {
+          continue;
+        }
+        unsigned resultIdx = cast<OpResult>(result).getResultNumber();
+        OpOperand *init = linalgOp.getDpsInitOperand(resultIdx);
+        AffineMap map = linalgOp.getMatchingIndexingMap(init);
+        assert(map.getNumDims() == numLoops);
+        iterTileSizes.merge(ts.mapToIterationSpace(map));
+      }
+      // Gather operand tile sizes into iteration space.
+      for (OpOperand &operand : linalgOp->getOpOperands()) {
+        auto &ts = getTileSizesFor(operand.get(),
+                                   operands[operand.getOperandNumber()]);
+        if (ts.empty()) {
+          continue;
+        }
+        AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+        assert(map.getNumDims() == numLoops);
+        iterTileSizes.merge(ts.mapToIterationSpace(map));
+      }
+      // Map iteration space tile sizes back to each operand.
+      for (OpOperand &operand : linalgOp->getOpOperands()) {
+        AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+        auto operandTileSizes = iterTileSizes.mapFromIterationSpace(map);
+        if (operandTileSizes.empty()) {
+          continue;
+        }
+        TileSizeLattice *operandLattice = operands[operand.getOperandNumber()];
+        propagateIfChanged(operandLattice,
+                           operandLattice->meet(operandTileSizes));
+      }
+      return success();
+    }
+
+    return success();
+  }
+
+  // Required by the base class. Non-forwarded branch operands (e.g., loop
+  // bounds, conditions) are scalars irrelevant to tile size propagation.
+  // Forwarded values (iter_args, yields) are handled by the framework via
+  // RegionBranchOpInterface.
+  void visitBranchOperand(OpOperand &operand) override {}
+  void visitCallOperand(OpOperand &operand) override {}
+  void
+  visitNonControlFlowArguments(RegionSuccessor &successor,
+                               ArrayRef<BlockArgument> arguments) override {}
+};
+
+//===----------------------------------------------------------------------===//
+// Result querying
+//===----------------------------------------------------------------------===//
+
+/// Gather tile sizes into the iteration space of a linalg op by looking up each
+/// operand's lattice state in the solver.
+static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp,
+                                            const DataFlowSolver &solver) {
+  unsigned numLoops = linalgOp.getNumLoops();
+  TileSizes iterTileSizes(numLoops);
+  for (OpOperand &operand : linalgOp->getOpOperands()) {
+    Value val = operand.get();
+    auto *lattice = solver.lookupState<TileSizeLattice>(val);
+    auto &ts = getTileSizesFor(val, lattice);
+    if (ts.empty()) {
+      continue;
+    }
+    AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+    assert(map.getNumDims() == numLoops);
+    iterTileSizes.merge(ts.mapToIterationSpace(map));
+  }
+  return iterTileSizes;
+}
+
+/// Given a linalg op and the solver, compute per-dimension tile sizes.
+/// Returns a vector of one tile size per iteration dimension, or nullopt if
+/// any dimension is uninitialized or overdefined.
+static std::optional<SmallVector<int64_t>>
+getPerDimTileSizes(linalg::LinalgOp linalgOp, const DataFlowSolver &solver) {
+  TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver);
+  if (!tileSizes.isDefined()) {
+    return std::nullopt;
+  }
+  assert(tileSizes.rank() == linalgOp.getNumLoops());
+  return tileSizes.getDims();
+}
+
+//===----------------------------------------------------------------------===//
+// MaterializeVectorTileSizesPass
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DEF_MATERIALIZEVECTORTILESIZESPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+
+class MaterializeVectorTileSizesPass final
+    : public impl::MaterializeVectorTileSizesPassBase<
+          MaterializeVectorTileSizesPass> {
+public:
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+
+    DataFlowSolver solver;
+    dataflow::loadBaselineAnalyses(solver);
+    solver.load<TileSizeForwardAnalysis>();
+    SymbolTableCollection symbolTable;
+    solver.load<TileSizeBackwardAnalysis>(symbolTable);
+
+    if (failed(solver.initializeAndRun(funcOp))) {
+      return signalPassFailure();
+    }
+
+    funcOp->walk([&](linalg::LinalgOp linalgOp) {
+      std::optional<SmallVector<int64_t>> perDimSizes =
+          getPerDimTileSizes(linalgOp, solver);
+      if (!perDimSizes) {
+        LDBG() << "Analysis did not determine tile size for " << *linalgOp;
+        return;
+      }
+
+      LDBG() << "Materializing tile size on " << *linalgOp;
+      linalgOp->setAttr(
+          kVectorTileSizesAttrName,
+          DenseI64ArrayAttr::get(linalgOp->getContext(), *perDimSizes));
+    });
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp
index b41ade6f56c1..11e369ba5575 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp
+++
b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp @@ -74,25 +74,6 @@ namespace mlir::iree_compiler { using namespace IREE::VectorExt; -/// Sentinel values for per-dimension tile size state. Valid tile sizes are -/// always positive. -constexpr int64_t kUninitialized = 0; -constexpr int64_t kOverdefined = -1; - -/// Merge a single dimension's tile size value. Returns the merged result. -static int64_t mergeDim(int64_t a, int64_t b) { - if (a == kOverdefined || b == kOverdefined) { - return kOverdefined; - } - if (a == kUninitialized) { - return b; - } - if (b == kUninitialized) { - return a; - } - return a == b ? a : kOverdefined; -} - /// Per-dimension tile sizes. Each dimension holds a single tile size value, or /// one of the sentinel values kUninitialized/kOverdefined. Satisfies the /// requirements for use as a `Lattice` value type. @@ -101,21 +82,20 @@ class TileSizes { TileSizes() = default; explicit TileSizes(unsigned rank) : dims(rank, kUninitialized) {} - /// Construct from a single concrete tile size (one value per dimension). - static TileSizes fromSizes(ArrayRef sizes) { - TileSizes result; - result.dims.assign(sizes.begin(), sizes.end()); - return result; - } + /// Construct from concrete tile sizes (one value per dimension). + TileSizes(ArrayRef sizes) : dims(sizes) {} unsigned rank() const { return dims.size(); } bool empty() const { return dims.empty(); } + const llvm::SmallVector &getDims() const { return dims; } int64_t operator[](unsigned i) const { return dims[i]; } /// Returns true if all dimensions have a defined (positive) tile size. bool isDefined() const { - return !empty() && llvm::all_of(dims, [](int64_t v) { return v > 0; }); + return !empty() && llvm::all_of(dims, [](int64_t v) { + return v != kUninitialized && v != kOverdefined; + }); } /// Returns true if any dimension is overdefined. @@ -139,9 +119,8 @@ class TileSizes { } /// Map from operand space to iteration space via an indexing map. 
- TileSizes mapToIterationSpace(AffineMap indexingMap, - unsigned numLoops) const { - TileSizes result(numLoops); + TileSizes mapToIterationSpace(AffineMap indexingMap) const { + TileSizes result(indexingMap.getNumDims()); for (unsigned i = 0; i < indexingMap.getNumResults(); ++i) { auto dimExpr = dyn_cast(indexingMap.getResult(i)); if (!dimExpr) { @@ -192,7 +171,7 @@ class TileSizes { if (v == kUninitialized) { os << "?"; } else if (v == kOverdefined) { - os << "T"; + os << ""; } else { os << v; } @@ -202,6 +181,25 @@ class TileSizes { private: SmallVector dims; + + /// Sentinel values for per-dimension tile size state. Valid tile sizes are + /// always positive. + static constexpr int64_t kUninitialized = 0; + static constexpr int64_t kOverdefined = -1; + + /// Merge a single dimension's tile size value. Returns the merged result. + static int64_t mergeDim(int64_t a, int64_t b) { + if (a == kOverdefined || b == kOverdefined) { + return kOverdefined; + } + if (a == kUninitialized) { + return b; + } + if (b == kUninitialized) { + return a; + } + return a == b ? a : kOverdefined; + } }; /// Returns true if the operation is trivially duplicatable and should not @@ -242,18 +240,17 @@ class TileSizeLattice : public dataflow::Lattice { /// Read the TileSizes from a lattice, returning empty tile sizes if the lattice /// value is from a duplicatable operation. -static const TileSizes &getTileSizesFor(Value val, - const TileSizeLattice *lattice) { - static const TileSizes empty; +static const TileSizes getTileSizesFor(Value val, + const TileSizeLattice *lattice) { if (!lattice) { - return empty; + return {}; } auto &tileSizes = lattice->getValue(); if (tileSizes.empty()) { - return empty; + return {}; } if (isDuplicatable(val)) { - return empty; + return {}; } return tileSizes; } @@ -277,28 +274,28 @@ class TileSizeForwardAnalysis // to_layout: seed from layout, don't propagate operand forward. 
if (auto toLayout = dyn_cast(op)) { LDBG() << "Anchor: " << toLayout; - auto tileSizes = - TileSizes::fromSizes(toLayout.getLayout().getUndistributedShape()); + TileSizes tileSizes(toLayout.getLayout().getUndistributedShape()); propagateIfChanged(results[0], results[0]->join(tileSizes)); return success(); } - // linalg.generic: propagate through indexing maps. - if (auto genericOp = dyn_cast(op)) { - unsigned numLoops = genericOp.getNumLoops(); + // Linalg ops: propagate through indexing maps. + if (auto linalgOp = dyn_cast(op)) { + unsigned numLoops = linalgOp.getNumLoops(); TileSizes iterTileSizes(numLoops); - for (OpOperand &operand : genericOp->getOpOperands()) { + for (OpOperand &operand : linalgOp->getOpOperands()) { auto &ts = getTileSizesFor(operand.get(), operands[operand.getOperandNumber()]); if (ts.empty()) { continue; } - AffineMap map = genericOp.getMatchingIndexingMap(&operand); - iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); + AffineMap map = linalgOp.getMatchingIndexingMap(&operand); + assert(map.getNumDims() == numLoops); + iterTileSizes.merge(ts.mapToIterationSpace(map)); } - for (unsigned i = 0; i < genericOp.getNumDpsInits(); ++i) { - OpOperand *init = genericOp.getDpsInitOperand(i); - AffineMap map = genericOp.getMatchingIndexingMap(init); + for (unsigned i = 0; i < linalgOp.getNumDpsInits(); ++i) { + OpOperand *init = linalgOp.getDpsInitOperand(i); + AffineMap map = linalgOp.getMatchingIndexingMap(init); auto resultTileSizes = iterTileSizes.mapFromIterationSpace(map); if (!resultTileSizes.empty()) { propagateIfChanged(results[i], results[i]->join(resultTileSizes)); @@ -339,7 +336,8 @@ class TileSizeBackwardAnalysis LogicalResult visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override { - // to_layout: propagate result tile sizes backward to input. + // to_layout is always an anchor op; Propagate tile sizes backward to the + // input. 
if (auto toLayout = dyn_cast(op)) { auto &ts = getTileSizesFor(toLayout.getResult(), results[0]); if (!ts.empty()) { @@ -349,35 +347,37 @@ class TileSizeBackwardAnalysis return success(); } - // linalg.generic: propagate through indexing maps. - if (auto genericOp = dyn_cast(op)) { - unsigned numLoops = genericOp.getNumLoops(); + // Linalg ops: propagate through indexing maps. + if (auto linalgOp = dyn_cast(op)) { + unsigned numLoops = linalgOp.getNumLoops(); TileSizes iterTileSizes(numLoops); // Gather result tile sizes into iteration space via DPS init maps. for (auto [result, resultLattice] : - llvm::zip(genericOp.getResults(), results)) { + llvm::zip(linalgOp.getOperation()->getResults(), results)) { auto &ts = getTileSizesFor(result, resultLattice); if (ts.empty()) { continue; } unsigned resultIdx = cast(result).getResultNumber(); - OpOperand *init = genericOp.getDpsInitOperand(resultIdx); - AffineMap map = genericOp.getMatchingIndexingMap(init); - iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); + OpOperand *init = linalgOp.getDpsInitOperand(resultIdx); + AffineMap map = linalgOp.getMatchingIndexingMap(init); + assert(map.getNumDims() == numLoops); + iterTileSizes.merge(ts.mapToIterationSpace(map)); } // Gather operand tile sizes into iteration space. - for (OpOperand &operand : genericOp->getOpOperands()) { + for (OpOperand &operand : linalgOp->getOpOperands()) { auto &ts = getTileSizesFor(operand.get(), operands[operand.getOperandNumber()]); if (ts.empty()) { continue; } - AffineMap map = genericOp.getMatchingIndexingMap(&operand); - iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); + AffineMap map = linalgOp.getMatchingIndexingMap(&operand); + assert(map.getNumDims() == numLoops); + iterTileSizes.merge(ts.mapToIterationSpace(map)); } // Map iteration space tile sizes back to each operand. 
- for (OpOperand &operand : genericOp->getOpOperands()) { - AffineMap map = genericOp.getMatchingIndexingMap(&operand); + for (OpOperand &operand : linalgOp->getOpOperands()) { + AffineMap map = linalgOp.getMatchingIndexingMap(&operand); auto operandTileSizes = iterTileSizes.mapFromIterationSpace(map); if (operandTileSizes.empty()) { continue; @@ -438,7 +438,8 @@ static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, continue; } AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - iterTileSizes.merge(ts.mapToIterationSpace(map, numLoops)); + assert(map.getNumDims() == numLoops); + iterTileSizes.merge(ts.mapToIterationSpace(map)); } return iterTileSizes; } @@ -448,17 +449,12 @@ static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, /// any dimension is uninitialized or overdefined. static std::optional> getPerDimTileSizes(linalg::LinalgOp linalgOp, const DataFlowSolver &solver) { - auto tileSizes = getIterationSpaceTileSizes(linalgOp, solver); + TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver); if (!tileSizes.isDefined()) { return std::nullopt; } - unsigned numLoops = linalgOp.getNumLoops(); - - SmallVector results; - for (unsigned i = 0; i < numLoops; ++i) { - results.push_back(tileSizes[i]); - } - return results; + assert(tileSizes.rank() == linalgOp.getNumLoops()); + return tileSizes.getDims(); } //===----------------------------------------------------------------------===// @@ -488,7 +484,8 @@ class MaterializeVectorTileSizesPass final } funcOp->walk([&](linalg::LinalgOp linalgOp) { - auto perDimSizes = getPerDimTileSizes(linalgOp, solver); + std::optional> perDimSizes = + getPerDimTileSizes(linalgOp, solver); if (!perDimSizes) { return; } From 8e9a9aa943f8223d803a845c4ca7d4e79bd0c246 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Thu, 19 Mar 2026 17:07:28 +0000 Subject: [PATCH 6/8] Delete old implementation file Signed-off-by: Lukas Sommer --- .../Codegen/Common/VectorTileSizeAnalysis.cpp | 502 
------------------ 1 file changed, 502 deletions(-) delete mode 100644 compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp deleted file mode 100644 index 11e369ba5575..000000000000 --- a/compiler/src/iree/compiler/Codegen/Common/VectorTileSizeAnalysis.cpp +++ /dev/null @@ -1,502 +0,0 @@ -// Copyright 2026 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Codegen/Common/Passes.h" -#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" -#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" - -#include "llvm/Support/DebugLog.h" -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" -#include "mlir/Analysis/DataFlow/Utils.h" -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/SymbolTable.h" - -#define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis" - -// The purpose of this analysis is to propagate information about the -// vector tile size across the operation graph. The vector tile size is -// important information for the vectorization of operations. For example, the -// vector tile size can be used by GenericVectorization to introduce the -// necessary masking in the presence of padding/masking. -// -// The analysis is a bi-directional dataflow analysis building on top of the -// upstream MLIR dataflow analysis framework. To implement the bi-directional -// propagation, it combines a sparse forward analysis and a sparse backward -// analysis in the same solver. -// -// The lattice for the dataflow analysis is shared by both analyses (forward and -// backward). 
For each N-dimensional ShapedType SSA value, we have a lattice -// element comprising N dimensions, where each dimension is in one of three -// states: uninitialized (bottom), a single tile size value, or overdefined -// (top). The bottom (uninitialized) state is the identity for the merge -// operation. Merging two equal tile sizes for the same dimension is identity. -// Merging two different tile sizes for the same dimension results in the -// overdefined state. Overdefined is absorbing — once a dimension reaches -// overdefined, it stays overdefined. -// -// The lattice is initialized from anchor operations that provide information -// about vector tile size (e.g., `to_layout`). -// -// Forward propagation and backward propagation work similarly: -// - For elementwise operations, tile sizes from the different operands -// (forward) or results (backwards) are merged. The merged lattice state is -// then propagated to all results (forward) or operands (backward). -// - For linalg.generic operations, all available information from operands -// (forward) or results & operands (backward) is mapped to the iteration space -// based on indexing maps and merged into a single lattice state. That lattice -// state in the iteration space is then mapped to each result (forward) or -// operand (backward) based on indexing maps and the mapped state is -// propagated. -// -// Duplicatable operations such as `tensor.empty`, constants, and generator -// linalg ops (e.g. linalg.fill) are excluded from propagation entirely. CSE -// connects otherwise unrelated compute ops by deduplicating their DPS init -// operands to a single tensor.empty (or similar). To avoid cross-polluting the -// vector tile size of unrelated operations, tile sizes from duplicatable -// operations are never propagated. -// -// For some other operations, no propagation rules are defined on purpose. 
For -// example, `extract_slice` and `insert_slice` operations are natural boundaries -// of tiling/padding, therefore no information is propagated across them. -// -// After the dataflow solver reaches a fixpoint, the -// MaterializeVectorTileSizesPass materializes the result as a discardable -// attribute. Only dimensions with a single defined (non-overdefined) tile size -// are materialized. Operations where any dimension is uninitialized or -// overdefined do not receive the attribute. - -namespace mlir::iree_compiler { - -using namespace IREE::VectorExt; - -/// Per-dimension tile sizes. Each dimension holds a single tile size value, or -/// one of the sentinel values kUninitialized/kOverdefined. Satisfies the -/// requirements for use as a `Lattice` value type. -class TileSizes { -public: - TileSizes() = default; - explicit TileSizes(unsigned rank) : dims(rank, kUninitialized) {} - - /// Construct from concrete tile sizes (one value per dimension). - TileSizes(ArrayRef sizes) : dims(sizes) {} - - unsigned rank() const { return dims.size(); } - bool empty() const { return dims.empty(); } - const llvm::SmallVector &getDims() const { return dims; } - - int64_t operator[](unsigned i) const { return dims[i]; } - - /// Returns true if all dimensions have a defined (positive) tile size. - bool isDefined() const { - return !empty() && llvm::all_of(dims, [](int64_t v) { - return v != kUninitialized && v != kOverdefined; - }); - } - - /// Returns true if any dimension is overdefined. - bool isOverdefined() const { - return llvm::any_of(dims, [](int64_t v) { return v == kOverdefined; }); - } - - /// Merge tile sizes from `other` into this. Uninitialized is identity. 
- void merge(const TileSizes &other) { - if (empty()) { - *this = other; - return; - } - if (other.empty()) { - return; - } - assert(rank() == other.rank() && "rank mismatch"); - for (unsigned i = 0; i < rank(); ++i) { - dims[i] = mergeDim(dims[i], other.dims[i]); - } - } - - /// Map from operand space to iteration space via an indexing map. - TileSizes mapToIterationSpace(AffineMap indexingMap) const { - TileSizes result(indexingMap.getNumDims()); - for (unsigned i = 0; i < indexingMap.getNumResults(); ++i) { - auto dimExpr = dyn_cast(indexingMap.getResult(i)); - if (!dimExpr) { - continue; - } - unsigned iterDim = dimExpr.getPosition(); - result.dims[iterDim] = mergeDim(result.dims[iterDim], dims[i]); - } - return result; - } - - /// Map from iteration space to operand space via an indexing map. - /// Returns empty TileSizes if any operand dim can't be determined. - TileSizes mapFromIterationSpace(AffineMap indexingMap) const { - unsigned numResults = indexingMap.getNumResults(); - TileSizes result(numResults); - for (unsigned i = 0; i < numResults; ++i) { - auto dimExpr = dyn_cast(indexingMap.getResult(i)); - if (!dimExpr) { - return {}; - } - unsigned iterDim = dimExpr.getPosition(); - if (iterDim >= rank() || dims[iterDim] == kUninitialized) { - return {}; - } - result.dims[i] = dims[iterDim]; - } - return result; - } - - /// Lattice join: per-dimension merge. Uninitialized is identity. - static TileSizes join(const TileSizes &lhs, const TileSizes &rhs) { - TileSizes result = lhs; - result.merge(rhs); - return result; - } - - /// Lattice meet: same as join (both directions merge the same way). 
- static TileSizes meet(const TileSizes &lhs, const TileSizes &rhs) { - return join(lhs, rhs); - } - - bool operator==(const TileSizes &rhs) const { return dims == rhs.dims; } - - void print(raw_ostream &os) const { - os << "["; - llvm::interleaveComma(dims, os, [&](int64_t v) { - if (v == kUninitialized) { - os << "?"; - } else if (v == kOverdefined) { - os << ""; - } else { - os << v; - } - }); - os << "]"; - } - -private: - SmallVector dims; - - /// Sentinel values for per-dimension tile size state. Valid tile sizes are - /// always positive. - static constexpr int64_t kUninitialized = 0; - static constexpr int64_t kOverdefined = -1; - - /// Merge a single dimension's tile size value. Returns the merged result. - static int64_t mergeDim(int64_t a, int64_t b) { - if (a == kOverdefined || b == kOverdefined) { - return kOverdefined; - } - if (a == kUninitialized) { - return b; - } - if (b == kUninitialized) { - return a; - } - return a == b ? a : kOverdefined; - } -}; - -/// Returns true if the operation is trivially duplicatable and should not -/// propagate tile sizes across independent consumers. -static bool isDuplicatable(Value val) { - Operation *defOp = val.getDefiningOp(); - if (!defOp) { - return false; - } - if (isa(defOp)) { - return true; - } - if (defOp->hasTrait()) { - return true; - } - // A linalg op that doesn't read any tensor data (e.g., linalg.fill or a - // fill-like linalg.generic broadcasting a scalar) is a generator and - // duplicatable. 
- if (auto linalgOp = dyn_cast(defOp)) { - if (llvm::none_of(linalgOp->getOpOperands(), [&](OpOperand &operand) { - return isa(operand.get().getType()) && - linalgOp.payloadUsesValueFromOperand(&operand); - })) { - return true; - } - } - return false; -} - -//===----------------------------------------------------------------------===// -// Lattice and analysis definitions -//===----------------------------------------------------------------------===// - -class TileSizeLattice : public dataflow::Lattice { -public: - using Lattice::Lattice; -}; - -/// Read the TileSizes from a lattice, returning empty tile sizes if the lattice -/// value is from a duplicatable operation. -static const TileSizes getTileSizesFor(Value val, - const TileSizeLattice *lattice) { - if (!lattice) { - return {}; - } - auto &tileSizes = lattice->getValue(); - if (tileSizes.empty()) { - return {}; - } - if (isDuplicatable(val)) { - return {}; - } - return tileSizes; -} - -/// Forward analysis: propagates tile sizes from operands to results. -/// Control flow through scf.for/scf.if is handled automatically by the -/// framework via RegionBranchOpInterface. -class TileSizeForwardAnalysis - : public dataflow::SparseForwardDataFlowAnalysis { -public: - using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; - - void setToEntryState(TileSizeLattice *lattice) override { - // Entry state is uninitialized (identity for join). - propagateIfChanged(lattice, lattice->join(TileSizes())); - } - - LogicalResult visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) override { - // to_layout: seed from layout, don't propagate operand forward. - if (auto toLayout = dyn_cast(op)) { - LDBG() << "Anchor: " << toLayout; - TileSizes tileSizes(toLayout.getLayout().getUndistributedShape()); - propagateIfChanged(results[0], results[0]->join(tileSizes)); - return success(); - } - - // Linalg ops: propagate through indexing maps. 
- if (auto linalgOp = dyn_cast(op)) { - unsigned numLoops = linalgOp.getNumLoops(); - TileSizes iterTileSizes(numLoops); - for (OpOperand &operand : linalgOp->getOpOperands()) { - auto &ts = getTileSizesFor(operand.get(), - operands[operand.getOperandNumber()]); - if (ts.empty()) { - continue; - } - AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); - } - for (unsigned i = 0; i < linalgOp.getNumDpsInits(); ++i) { - OpOperand *init = linalgOp.getDpsInitOperand(i); - AffineMap map = linalgOp.getMatchingIndexingMap(init); - auto resultTileSizes = iterTileSizes.mapFromIterationSpace(map); - if (!resultTileSizes.empty()) { - propagateIfChanged(results[i], results[i]->join(resultTileSizes)); - } - } - return success(); - } - - // Elementwise ops: propagate to all results. - if (OpTrait::hasElementwiseMappableTraits(op)) { - TileSizes combined; - for (auto [operandLattice, operandVal] : - llvm::zip(operands, op->getOperands())) { - combined.merge(getTileSizesFor(operandVal, operandLattice)); - } - for (TileSizeLattice *result : results) { - propagateIfChanged(result, result->join(combined)); - } - return success(); - } - - return success(); - } -}; - -/// Backward analysis: propagates tile sizes from results to operands. -/// Control flow through scf.for/scf.if is handled automatically by the -/// framework via RegionBranchOpInterface. -class TileSizeBackwardAnalysis - : public dataflow::SparseBackwardDataFlowAnalysis { -public: - using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; - - void setToExitState(TileSizeLattice *lattice) override { - // Exit state is uninitialized (identity for meet). - } - - LogicalResult - visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override { - // to_layout is always an anchor op; Propagate tile sizes backward to the - // input. 
- if (auto toLayout = dyn_cast(op)) { - auto &ts = getTileSizesFor(toLayout.getResult(), results[0]); - if (!ts.empty()) { - TileSizeLattice *inputLattice = operands[0]; - propagateIfChanged(inputLattice, inputLattice->meet(ts)); - } - return success(); - } - - // Linalg ops: propagate through indexing maps. - if (auto linalgOp = dyn_cast(op)) { - unsigned numLoops = linalgOp.getNumLoops(); - TileSizes iterTileSizes(numLoops); - // Gather result tile sizes into iteration space via DPS init maps. - for (auto [result, resultLattice] : - llvm::zip(linalgOp.getOperation()->getResults(), results)) { - auto &ts = getTileSizesFor(result, resultLattice); - if (ts.empty()) { - continue; - } - unsigned resultIdx = cast(result).getResultNumber(); - OpOperand *init = linalgOp.getDpsInitOperand(resultIdx); - AffineMap map = linalgOp.getMatchingIndexingMap(init); - assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); - } - // Gather operand tile sizes into iteration space. - for (OpOperand &operand : linalgOp->getOpOperands()) { - auto &ts = getTileSizesFor(operand.get(), - operands[operand.getOperandNumber()]); - if (ts.empty()) { - continue; - } - AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); - } - // Map iteration space tile sizes back to each operand. - for (OpOperand &operand : linalgOp->getOpOperands()) { - AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - auto operandTileSizes = iterTileSizes.mapFromIterationSpace(map); - if (operandTileSizes.empty()) { - continue; - } - TileSizeLattice *operandLattice = operands[operand.getOperandNumber()]; - propagateIfChanged(operandLattice, - operandLattice->meet(operandTileSizes)); - } - return success(); - } - - // Elementwise ops: propagate to all operands. 
- if (OpTrait::hasElementwiseMappableTraits(op)) { - TileSizes combined; - for (auto [resultVal, resultLattice] : - llvm::zip(op->getResults(), results)) { - combined.merge(getTileSizesFor(resultVal, resultLattice)); - } - for (auto [operandLattice, operandVal] : - llvm::zip(operands, op->getOperands())) { - if (!isa(operandVal.getType())) { - continue; - } - propagateIfChanged(operandLattice, operandLattice->meet(combined)); - } - return success(); - } - - return success(); - } - - // Required by the base class. Non-forwarded branch operands (e.g., loop - // bounds, conditions) are scalars irrelevant to tile size propagation. - // Forwarded values (iter_args, yields) are handled by the framework via - // RegionBranchOpInterface. - void visitBranchOperand(OpOperand &operand) override {} - void visitCallOperand(OpOperand &operand) override {} - void - visitNonControlFlowArguments(RegionSuccessor &successor, - ArrayRef arguments) override {} -}; - -//===----------------------------------------------------------------------===// -// Result querying -//===----------------------------------------------------------------------===// - -/// Gather tile sizes into the iteration space of a linalg op by looking up each -/// operand's lattice state in the solver. -static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, - const DataFlowSolver &solver) { - unsigned numLoops = linalgOp.getNumLoops(); - TileSizes iterTileSizes(numLoops); - for (OpOperand &operand : linalgOp->getOpOperands()) { - Value val = operand.get(); - auto *lattice = solver.lookupState(val); - auto &ts = getTileSizesFor(val, lattice); - if (ts.empty()) { - continue; - } - AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); - } - return iterTileSizes; -} - -/// Given a linalg op and the solver, compute per-dimension tile sizes. 
-/// Returns a vector of one tile size per iteration dimension, or nullopt if -/// any dimension is uninitialized or overdefined. -static std::optional> -getPerDimTileSizes(linalg::LinalgOp linalgOp, const DataFlowSolver &solver) { - TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver); - if (!tileSizes.isDefined()) { - return std::nullopt; - } - assert(tileSizes.rank() == linalgOp.getNumLoops()); - return tileSizes.getDims(); -} - -//===----------------------------------------------------------------------===// -// MaterializeVectorTileSizesPass -//===----------------------------------------------------------------------===// - -#define GEN_PASS_DEF_MATERIALIZEVECTORTILESIZESPASS -#include "iree/compiler/Codegen/Common/Passes.h.inc" - -namespace { - -class MaterializeVectorTileSizesPass final - : public impl::MaterializeVectorTileSizesPassBase< - MaterializeVectorTileSizesPass> { -public: - void runOnOperation() override { - auto funcOp = getOperation(); - - DataFlowSolver solver; - dataflow::loadBaselineAnalyses(solver); - solver.load(); - SymbolTableCollection symbolTable; - solver.load(symbolTable); - - if (failed(solver.initializeAndRun(funcOp))) { - return signalPassFailure(); - } - - funcOp->walk([&](linalg::LinalgOp linalgOp) { - std::optional> perDimSizes = - getPerDimTileSizes(linalgOp, solver); - if (!perDimSizes) { - return; - } - - LDBG() << "Materializing tile size on " << *linalgOp; - linalgOp->setAttr( - kVectorTileSizesAttrName, - DenseI64ArrayAttr::get(linalgOp->getContext(), *perDimSizes)); - }); - } -}; - -} // namespace -} // namespace mlir::iree_compiler From 93e18b5d8221d5eb60a4d8a611a3f825a2e2c9cb Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Mon, 23 Mar 2026 08:32:45 +0000 Subject: [PATCH 7/8] Address more PR feedback Signed-off-by: Lukas Sommer --- .../Common/MaterializeVectorTileSizes.cpp | 93 ++++++++----------- .../test/materialize_vector_tile_sizes.mlir | 2 +- 2 files changed, 41 insertions(+), 54 deletions(-) diff 
--git a/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp index 2b13bf9654b7..8d2f30405f21 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp @@ -16,7 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/SymbolTable.h" -#define DEBUG_TYPE "iree-codegen-vector-tile-size-analysis" +#define DEBUG_TYPE "iree-codegen-materialize-vector-tile-sizes" // The purpose of this analysis is to propagate information about the // vector tile size across the operation graph. The vector tile size is @@ -83,11 +83,12 @@ class TileSizes { unsigned rank() const { return dims.size(); } bool empty() const { return dims.empty(); } - const llvm::SmallVector &getDims() const { return dims; } + const ArrayRef getDims() const { return dims; } int64_t operator[](unsigned i) const { return dims[i]; } - /// Returns true if all dimensions have a defined (positive) tile size. + /// Returns true if the tile sizes are non-empty and every dimension has a + /// concrete tile size (not uninitialized or overdefined). bool isDefined() const { return !empty() && llvm::all_of(dims, [](int64_t v) { return v != kUninitialized && v != kOverdefined; @@ -236,12 +237,11 @@ class TileSizeLattice : public dataflow::Lattice { /// Read the TileSizes from a lattice, returning empty tile sizes if the lattice /// value is from a duplicatable operation. 
-static const TileSizes getTileSizesFor(Value val, - const TileSizeLattice *lattice) { +static TileSizes getTileSizesFor(Value val, const TileSizeLattice *lattice) { if (!lattice) { return {}; } - auto &tileSizes = lattice->getValue(); + const TileSizes &tileSizes = lattice->getValue(); if (tileSizes.empty()) { return {}; } @@ -280,22 +280,20 @@ class TileSizeForwardAnalysis unsigned numLoops = linalgOp.getNumLoops(); TileSizes iterTileSizes(numLoops); for (OpOperand &operand : linalgOp->getOpOperands()) { - auto &ts = getTileSizesFor(operand.get(), - operands[operand.getOperandNumber()]); - if (ts.empty()) { + TileSizes tileSizes = getTileSizesFor( + operand.get(), operands[operand.getOperandNumber()]); + if (tileSizes.empty()) { continue; } AffineMap map = linalgOp.getMatchingIndexingMap(&operand); assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); + iterTileSizes.merge(tileSizes.mapToIterationSpace(map)); } for (unsigned i = 0; i < linalgOp.getNumDpsInits(); ++i) { OpOperand *init = linalgOp.getDpsInitOperand(i); AffineMap map = linalgOp.getMatchingIndexingMap(init); - auto resultTileSizes = iterTileSizes.mapFromIterationSpace(map); - if (!resultTileSizes.empty()) { - propagateIfChanged(results[i], results[i]->join(resultTileSizes)); - } + TileSizes resultTileSizes = iterTileSizes.mapFromIterationSpace(map); + propagateIfChanged(results[i], results[i]->join(resultTileSizes)); } return success(); } @@ -322,11 +320,9 @@ class TileSizeBackwardAnalysis // to_layout is always an anchor op; Propagate tile sizes backward to the // input. 
if (auto toLayout = dyn_cast(op)) { - auto &ts = getTileSizesFor(toLayout.getResult(), results[0]); - if (!ts.empty()) { - TileSizeLattice *inputLattice = operands[0]; - propagateIfChanged(inputLattice, inputLattice->meet(ts)); - } + TileSizes tileSizes = getTileSizesFor(toLayout.getResult(), results[0]); + TileSizeLattice *inputLattice = operands[0]; + propagateIfChanged(inputLattice, inputLattice->meet(tileSizes)); return success(); } @@ -334,34 +330,34 @@ class TileSizeBackwardAnalysis if (auto linalgOp = dyn_cast(op)) { unsigned numLoops = linalgOp.getNumLoops(); TileSizes iterTileSizes(numLoops); - // Gather result tile sizes into iteration space via DPS init maps. + // Gather result tile sizes into iteration space via init maps. for (auto [result, resultLattice] : - llvm::zip(linalgOp.getOperation()->getResults(), results)) { - auto &ts = getTileSizesFor(result, resultLattice); - if (ts.empty()) { + llvm::zip_equal(linalgOp->getResults(), results)) { + TileSizes tileSizes = getTileSizesFor(result, resultLattice); + if (tileSizes.empty()) { continue; } unsigned resultIdx = cast(result).getResultNumber(); OpOperand *init = linalgOp.getDpsInitOperand(resultIdx); AffineMap map = linalgOp.getMatchingIndexingMap(init); assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); + iterTileSizes.merge(tileSizes.mapToIterationSpace(map)); } // Gather operand tile sizes into iteration space. 
for (OpOperand &operand : linalgOp->getOpOperands()) { - auto &ts = getTileSizesFor(operand.get(), - operands[operand.getOperandNumber()]); - if (ts.empty()) { + TileSizes tileSizes = getTileSizesFor( + operand.get(), operands[operand.getOperandNumber()]); + if (tileSizes.empty()) { continue; } AffineMap map = linalgOp.getMatchingIndexingMap(&operand); assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); + iterTileSizes.merge(tileSizes.mapToIterationSpace(map)); } // Map iteration space tile sizes back to each operand. for (OpOperand &operand : linalgOp->getOpOperands()) { AffineMap map = linalgOp.getMatchingIndexingMap(&operand); - auto operandTileSizes = iterTileSizes.mapFromIterationSpace(map); + TileSizes operandTileSizes = iterTileSizes.mapFromIterationSpace(map); if (operandTileSizes.empty()) { continue; } @@ -398,31 +394,18 @@ static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, TileSizes iterTileSizes(numLoops); for (OpOperand &operand : linalgOp->getOpOperands()) { Value val = operand.get(); - auto *lattice = solver.lookupState(val); - auto &ts = getTileSizesFor(val, lattice); - if (ts.empty()) { + const TileSizeLattice *lattice = solver.lookupState(val); + TileSizes tileSize = getTileSizesFor(val, lattice); + if (tileSize.empty()) { continue; } AffineMap map = linalgOp.getMatchingIndexingMap(&operand); assert(map.getNumDims() == numLoops); - iterTileSizes.merge(ts.mapToIterationSpace(map)); + iterTileSizes.merge(tileSize.mapToIterationSpace(map)); } return iterTileSizes; } -/// Given a linalg op and the solver, compute per-dimension tile sizes. -/// Returns a vector of one tile size per iteration dimension, or nullopt if -/// any dimension is uninitialized or overdefined. 
-static std::optional> -getPerDimTileSizes(linalg::LinalgOp linalgOp, const DataFlowSolver &solver) { - TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver); - if (!tileSizes.isDefined()) { - return std::nullopt; - } - assert(tileSizes.rank() == linalgOp.getNumLoops()); - return tileSizes.getDims(); -} - //===----------------------------------------------------------------------===// // MaterializeVectorTileSizesPass //===----------------------------------------------------------------------===// @@ -437,7 +420,7 @@ class MaterializeVectorTileSizesPass final MaterializeVectorTileSizesPass> { public: void runOnOperation() override { - auto funcOp = getOperation(); + FunctionOpInterface funcOp = getOperation(); DataFlowSolver solver; dataflow::loadBaselineAnalyses(solver); @@ -450,17 +433,21 @@ class MaterializeVectorTileSizesPass final } funcOp->walk([&](linalg::LinalgOp linalgOp) { - std::optional> perDimSizes = - getPerDimTileSizes(linalgOp, solver); - if (!perDimSizes) { - LDBG() << "Analysis did not determine tile size for" << *linalgOp; + TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver); + if (tileSizes.isOverdefined()) { + linalgOp.emitOpError() + << "tile size analysis did not determine a valid tile size"; + return; + } + if (!tileSizes.isDefined()) { + LDBG() << "Analysis did not determine tile size for " << *linalgOp; return; } + assert(tileSizes.rank() == linalgOp.getNumLoops()); - LDBG() << "Materializing tile size on " << *linalgOp; linalgOp->setAttr( kVectorTileSizesAttrName, - DenseI64ArrayAttr::get(linalgOp->getContext(), *perDimSizes)); + DenseI64ArrayAttr::get(linalgOp->getContext(), tileSizes.getDims())); }); } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir index 3efe010d150a..1e477e8c55be 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir 
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline='builtin.module(any(iree-codegen-materialize-vector-tile-sizes))' --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-materialize-vector-tile-sizes))' --split-input-file %s | FileCheck %s // Elementwise chain from to_layout anchor. From 5f34dd90341740cf2d702e45ac2e1968afc07735 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Tue, 24 Mar 2026 08:01:27 +0000 Subject: [PATCH 8/8] Address PR comments Signed-off-by: Lukas Sommer --- .../Common/MaterializeVectorTileSizes.cpp | 26 +++++++------------ .../test/materialize_vector_tile_sizes.mlir | 1 + 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp index 8d2f30405f21..a03d1e536567 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp @@ -16,6 +16,11 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/SymbolTable.h" +namespace mlir::iree_compiler { +#define GEN_PASS_DEF_MATERIALIZEVECTORTILESIZESPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" +} // namespace mlir::iree_compiler + #define DEBUG_TYPE "iree-codegen-materialize-vector-tile-sizes" // The purpose of this analysis is to propagate information about the @@ -83,7 +88,7 @@ class TileSizes { unsigned rank() const { return dims.size(); } bool empty() const { return dims.empty(); } - const ArrayRef getDims() const { return dims; } + ArrayRef getDims() const { return dims; } int64_t operator[](unsigned i) const { return dims[i]; } @@ -118,6 +123,10 @@ class TileSizes { /// Map from operand space to iteration space via an indexing map. 
TileSizes mapToIterationSpace(AffineMap indexingMap) const { TileSizes result(indexingMap.getNumDims()); + if (empty()) { + // Early return in case this candidate is empty. + return result; + } for (unsigned i = 0; i < indexingMap.getNumResults(); ++i) { auto dimExpr = dyn_cast(indexingMap.getResult(i)); if (!dimExpr) { @@ -282,9 +291,6 @@ class TileSizeForwardAnalysis for (OpOperand &operand : linalgOp->getOpOperands()) { TileSizes tileSizes = getTileSizesFor( operand.get(), operands[operand.getOperandNumber()]); - if (tileSizes.empty()) { - continue; - } AffineMap map = linalgOp.getMatchingIndexingMap(&operand); assert(map.getNumDims() == numLoops); iterTileSizes.merge(tileSizes.mapToIterationSpace(map)); @@ -334,9 +340,6 @@ class TileSizeBackwardAnalysis for (auto [result, resultLattice] : llvm::zip_equal(linalgOp->getResults(), results)) { TileSizes tileSizes = getTileSizesFor(result, resultLattice); - if (tileSizes.empty()) { - continue; - } unsigned resultIdx = cast(result).getResultNumber(); OpOperand *init = linalgOp.getDpsInitOperand(resultIdx); AffineMap map = linalgOp.getMatchingIndexingMap(init); @@ -347,9 +350,6 @@ class TileSizeBackwardAnalysis for (OpOperand &operand : linalgOp->getOpOperands()) { TileSizes tileSizes = getTileSizesFor( operand.get(), operands[operand.getOperandNumber()]); - if (tileSizes.empty()) { - continue; - } AffineMap map = linalgOp.getMatchingIndexingMap(&operand); assert(map.getNumDims() == numLoops); iterTileSizes.merge(tileSizes.mapToIterationSpace(map)); @@ -358,9 +358,6 @@ class TileSizeBackwardAnalysis for (OpOperand &operand : linalgOp->getOpOperands()) { AffineMap map = linalgOp.getMatchingIndexingMap(&operand); TileSizes operandTileSizes = iterTileSizes.mapFromIterationSpace(map); - if (operandTileSizes.empty()) { - continue; - } TileSizeLattice *operandLattice = operands[operand.getOperandNumber()]; propagateIfChanged(operandLattice, operandLattice->meet(operandTileSizes)); @@ -410,9 +407,6 @@ static TileSizes 
getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, // MaterializeVectorTileSizesPass //===----------------------------------------------------------------------===// -#define GEN_PASS_DEF_MATERIALIZEVECTORTILESIZESPASS -#include "iree/compiler/Codegen/Common/Passes.h.inc" - namespace { class MaterializeVectorTileSizesPass final diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir index 1e477e8c55be..514546bd6341 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir @@ -224,6 +224,7 @@ func.func @contraction_indexing_maps( func.func @scf_if_propagation(%arg0: tensor<512xf32>, %cond: i1) -> tensor<512xf32> { %empty = tensor.empty() : tensor<512xf32> %cst = arith.constant 0.0 : f32 + // CHECK-NOT: iree_codegen.vector_tile_sizes %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<512xf32>) -> tensor<512xf32> %if_result = scf.if %cond -> tensor<512xf32> { %laid_out = iree_vector_ext.to_layout %arg0 to layout(#layout_if) : tensor<512xf32>