diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 233f2760bf82..7e1a5bc2813f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -136,6 +136,7 @@ iree_compiler_cc_library(
         "MaterializeEncodingPatterns.cpp",
         "MaterializeTuningSpecs.cpp",
         "MaterializeVectorMasking.cpp",
+        "MaterializeVectorTileSizes.cpp",
         "MathTransform.cpp",
         "MemrefCopyToLinalg.cpp",
         "NormalizeLoopBounds.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 4535db2e6fde..3a043febbb2d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -129,6 +129,7 @@ iree_cc_library(
     "MaterializeEncodingPatterns.cpp"
     "MaterializeTuningSpecs.cpp"
     "MaterializeVectorMasking.cpp"
+    "MaterializeVectorTileSizes.cpp"
    "MathTransform.cpp"
    "MemrefCopyToLinalg.cpp"
    "NormalizeLoopBounds.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index e539bacf0354..d3cd248b4d1e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -30,8 +30,10 @@ namespace mlir::iree_compiler {
 #include "iree/compiler/Codegen/Common/Passes.h.inc"
 
 namespace {
-// Returns the vector sizes from the local lowering config or try to infer them
-// from the tensor shapes and tiled loops in the IR.
+
+// Returns the vector sizes from the local lowering config, materialized
+// tile size attributes, or tries to infer them from the tensor shapes and
+// tiled loops in the IR.
 static std::optional<SizesAndScalableFlags>
 getVectorSizes(Operation *op, bool useConfiguredVectorSizes) {
   // Get vector sizes from the lowering config, if available in the op itself.
@@ -80,6 +82,15 @@ getVectorSizes(Operation *op, bool useConfiguredVectorSizes) {
     LDBG() << "Failed to get configured vector sizes, fall back to inference";
   }
 
+  // Try to get vector sizes from the materialized tile size attribute.
+  if (auto tileSizesAttr =
+          op->getAttrOfType<DenseI64ArrayAttr>(kVectorTileSizesAttrName)) {
+    LDBG() << "Use vector sizes from materialized tile size attribute";
+    SmallVector<int64_t> vectorSizes(tileSizesAttr.asArrayRef());
+    SmallVector<bool> scalableFlags(vectorSizes.size(), false);
+    return std::make_pair(vectorSizes, scalableFlags);
+  }
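+  // Note: the lowering config above takes precedence over the materialized
+  // attribute; shape/loop-based inference below remains the fallback.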
+
   // Try to infer the vector sizes from the IR.
   std::optional<SmallVector<int64_t>> vectorSizes;
   SmallVector<bool> scalableFlags;
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp
new file mode 100644
index 000000000000..a03d1e536567
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeVectorTileSizes.cpp
@@ -0,0 +1,450 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
+
+#include "llvm/Support/DebugLog.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlow/Utils.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir::iree_compiler {
+#define GEN_PASS_DEF_MATERIALIZEVECTORTILESIZESPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+} // namespace mlir::iree_compiler
+
+#define DEBUG_TYPE "iree-codegen-materialize-vector-tile-sizes"
+
+// The purpose of this analysis is to propagate information about the vector
+// tile size across the operation graph. The vector tile size is an important
+// input for vectorizing operations: for example, GenericVectorization can use
+// it to introduce the necessary masking in the presence of padding.
+//
+// The analysis is a bi-directional dataflow analysis built on top of the
+// upstream MLIR dataflow analysis framework. To implement the bi-directional
+// propagation, it combines a sparse forward analysis and a sparse backward
+// analysis in the same solver.
+//
+// The lattice for the dataflow analysis is shared by both the forward and the
+// backward analysis. For each N-dimensional ShapedType SSA value, we have a
+// lattice element comprising N dimensions, where each dimension is in one of
+// three states: uninitialized (bottom), a single tile size value, or
+// overdefined (top). The bottom (uninitialized) state is the identity for the
+// merge operation. Merging two equal tile sizes for the same dimension yields
+// that tile size; merging two different tile sizes for the same dimension
+// results in the overdefined state. Overdefined is absorbing: once a
+// dimension reaches overdefined, it stays overdefined.
+//
+// The lattice is initialized from anchor operations that provide information
+// about the vector tile size (e.g., `to_layout`).
+//
+// Forward propagation and backward propagation work similarly:
+// - For linalg operations, all available information from operands (forward)
+//   or results & operands (backward) is mapped to the iteration space based
+//   on indexing maps and merged into a single lattice state. That lattice
+//   state in the iteration space is then mapped to each result (forward) or
+//   operand (backward) based on indexing maps, and the mapped state is
+//   propagated.
+//
+// Duplicatable operations such as `tensor.empty`, constants, and generator
+// linalg ops (e.g. linalg.fill) are excluded from propagation entirely. CSE
+// connects otherwise unrelated compute ops by deduplicating their DPS init
+// operands to a single tensor.empty (or similar). To avoid cross-polluting
+// the vector tile sizes of unrelated operations, tile sizes from duplicatable
+// operations are never propagated.
+//
+// For some other operations, no propagation rules are defined on purpose. For
+// example, `extract_slice` and `insert_slice` operations are natural
+// boundaries of tiling/padding, so no information is propagated across them.
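+//
+// As a small worked example (mirroring the first test for this pass):
+//
+//   %0 = iree_vector_ext.to_layout %a to layout(#layout) : tensor<63xf16>
+//   %1 = linalg.generic {...} ins(%0 : tensor<63xf16>) ... -> tensor<63xf16>
+//
+// If the undistributed shape of #layout is [64], the anchor seeds the
+// lattice of %0 with [64]; the generic maps that state through its indexing
+// maps into the iteration space and back out, so it and its transitive
+// elementwise consumers end up with a vector tile size of [64].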
+//
+// After the dataflow solver reaches a fixpoint, the
+// MaterializeVectorTileSizesPass materializes the result as a discardable
+// attribute. Only dimensions with a single defined (non-overdefined) tile
+// size are materialized. Operations where any dimension is uninitialized or
+// overdefined do not receive the attribute.
+
+namespace mlir::iree_compiler {
+
+using namespace IREE::VectorExt;
+
+/// Per-dimension tile sizes. Each dimension holds a single tile size value,
+/// or one of the sentinel values kUninitialized/kOverdefined. Satisfies the
+/// requirements for use as a `Lattice` value type.
+class TileSizes {
+public:
+  TileSizes() = default;
+  explicit TileSizes(unsigned rank) : dims(rank, kUninitialized) {}
+
+  /// Construct from concrete tile sizes (one value per dimension).
+  TileSizes(ArrayRef<int64_t> sizes) : dims(sizes) {}
+
+  unsigned rank() const { return dims.size(); }
+  bool empty() const { return dims.empty(); }
+  ArrayRef<int64_t> getDims() const { return dims; }
+
+  int64_t operator[](unsigned i) const { return dims[i]; }
+
+  /// Returns true if the tile sizes are non-empty and every dimension has a
+  /// concrete tile size (not uninitialized or overdefined).
+  bool isDefined() const {
+    return !empty() && llvm::all_of(dims, [](int64_t v) {
+      return v != kUninitialized && v != kOverdefined;
+    });
+  }
+
+  /// Returns true if any dimension is overdefined.
+  bool isOverdefined() const {
+    return llvm::any_of(dims, [](int64_t v) { return v == kOverdefined; });
+  }
+
+  /// Merge tile sizes from `other` into this. Uninitialized is identity.
+  void merge(const TileSizes &other) {
+    if (empty()) {
+      *this = other;
+      return;
+    }
+    if (other.empty()) {
+      return;
+    }
+    assert(rank() == other.rank() && "rank mismatch");
+    for (unsigned i = 0; i < rank(); ++i) {
+      dims[i] = mergeDim(dims[i], other.dims[i]);
+    }
+  }
+
+  /// Map from operand space to iteration space via an indexing map.
+  TileSizes mapToIterationSpace(AffineMap indexingMap) const {
+    TileSizes result(indexingMap.getNumDims());
+    if (empty()) {
+      // Nothing to map if this lattice value is empty.
+      return result;
+    }
+    for (unsigned i = 0; i < indexingMap.getNumResults(); ++i) {
+      auto dimExpr = dyn_cast<AffineDimExpr>(indexingMap.getResult(i));
+      if (!dimExpr) {
+        continue;
+      }
+      unsigned iterDim = dimExpr.getPosition();
+      result.dims[iterDim] = mergeDim(result.dims[iterDim], dims[i]);
+    }
+    return result;
+  }
+
+  /// Map from iteration space to operand space via an indexing map.
+  /// Returns empty TileSizes if any operand dim can't be determined.
+  TileSizes mapFromIterationSpace(AffineMap indexingMap) const {
+    unsigned numResults = indexingMap.getNumResults();
+    TileSizes result(numResults);
+    for (unsigned i = 0; i < numResults; ++i) {
+      auto dimExpr = dyn_cast<AffineDimExpr>(indexingMap.getResult(i));
+      if (!dimExpr) {
+        return {};
+      }
+      unsigned iterDim = dimExpr.getPosition();
+      if (iterDim >= rank() || dims[iterDim] == kUninitialized) {
+        return {};
+      }
+      result.dims[i] = dims[iterDim];
+    }
+    return result;
+  }
+
+  /// Lattice join: per-dimension merge. Uninitialized is identity.
+  static TileSizes join(const TileSizes &lhs, const TileSizes &rhs) {
+    TileSizes result = lhs;
+    result.merge(rhs);
+    return result;
+  }
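+
+  // For example, writing ? for an uninitialized dimension:
+  //   join([8, ?], [8, 4]) == [8, 4]
+  //   join([8, 4], [2, 4]) == [<overdefined>, 4]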
+
+  /// Lattice meet: same as join (both directions merge the same way).
+  static TileSizes meet(const TileSizes &lhs, const TileSizes &rhs) {
+    return join(lhs, rhs);
+  }
+
+  bool operator==(const TileSizes &rhs) const { return dims == rhs.dims; }
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleaveComma(dims, os, [&](int64_t v) {
+      if (v == kUninitialized) {
+        os << "?";
+      } else if (v == kOverdefined) {
+        os << "<overdefined>";
+      } else {
+        os << v;
+      }
+    });
+    os << "]";
+  }
+
+private:
+  SmallVector<int64_t> dims;
+
+  /// Sentinel values for per-dimension tile size state. Valid tile sizes are
+  /// always positive.
+  static constexpr int64_t kUninitialized = 0;
+  static constexpr int64_t kOverdefined = -1;
+
+  /// Merge a single dimension's tile size value. Returns the merged result.
+  static int64_t mergeDim(int64_t a, int64_t b) {
+    if (a == kOverdefined || b == kOverdefined) {
+      return kOverdefined;
+    }
+    if (a == kUninitialized) {
+      return b;
+    }
+    if (b == kUninitialized) {
+      return a;
+    }
+    return a == b ? a : kOverdefined;
+  }
+};
+
+/// Returns true if the operation is trivially duplicatable and should not
+/// propagate tile sizes across independent consumers.
+static bool isDuplicatable(Value val) {
+  Operation *defOp = val.getDefiningOp();
+  if (!defOp) {
+    return false;
+  }
+  if (isa<tensor::EmptyOp>(defOp)) {
+    return true;
+  }
+  if (defOp->hasTrait<OpTrait::ConstantLike>()) {
+    return true;
+  }
+  // A linalg op that doesn't read any tensor data (e.g., linalg.fill or a
+  // fill-like linalg.generic broadcasting a scalar) is a generator and
+  // duplicatable.
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(defOp)) {
+    if (llvm::none_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
+          return isa<ShapedType>(operand.get().getType()) &&
+                 linalgOp.payloadUsesValueFromOperand(&operand);
+        })) {
+      return true;
+    }
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Lattice and analysis definitions
+//===----------------------------------------------------------------------===//
+
+class TileSizeLattice : public dataflow::Lattice<TileSizes> {
+public:
+  using Lattice::Lattice;
+};
+
+/// Read the TileSizes from a lattice, returning empty tile sizes if the
+/// lattice is missing or empty, or if the value is defined by a duplicatable
+/// operation.
+static TileSizes getTileSizesFor(Value val, const TileSizeLattice *lattice) {
+  if (!lattice) {
+    return {};
+  }
+  const TileSizes &tileSizes = lattice->getValue();
+  if (tileSizes.empty()) {
+    return {};
+  }
+  if (isDuplicatable(val)) {
+    return {};
+  }
+  return tileSizes;
+}
+
+/// Forward analysis: propagates tile sizes from operands to results.
+/// Control flow through scf.for/scf.if is handled automatically by the
+/// framework via RegionBranchOpInterface.
+class TileSizeForwardAnalysis
+    : public dataflow::SparseForwardDataFlowAnalysis<TileSizeLattice> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  void setToEntryState(TileSizeLattice *lattice) override {
+    // Entry state is uninitialized (identity for join).
+    propagateIfChanged(lattice, lattice->join(TileSizes()));
+  }
+
+  LogicalResult
+  visitOperation(Operation *op, ArrayRef<const TileSizeLattice *> operands,
+                 ArrayRef<TileSizeLattice *> results) override {
+    // to_layout: seed from layout, don't propagate operand forward.
+    if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
+      LDBG() << "Anchor: " << toLayout;
+      TileSizes tileSizes(toLayout.getLayout().getUndistributedShape());
+      propagateIfChanged(results[0], results[0]->join(tileSizes));
+      return success();
+    }
+
+    // Linalg ops: propagate through indexing maps.
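+    // For example, if an operand with indexing map (d0, d1) -> (d1, d0) has
+    // tile sizes [8, 64], mapToIterationSpace yields d0 = 64 and d1 = 8; the
+    // merged iteration-space state is then projected onto each result via
+    // mapFromIterationSpace.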
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      unsigned numLoops = linalgOp.getNumLoops();
+      TileSizes iterTileSizes(numLoops);
+      for (OpOperand &operand : linalgOp->getOpOperands()) {
+        TileSizes tileSizes = getTileSizesFor(
+            operand.get(), operands[operand.getOperandNumber()]);
+        AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+        assert(map.getNumDims() == numLoops);
+        iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
+      }
+      for (unsigned i = 0; i < linalgOp.getNumDpsInits(); ++i) {
+        OpOperand *init = linalgOp.getDpsInitOperand(i);
+        AffineMap map = linalgOp.getMatchingIndexingMap(init);
+        TileSizes resultTileSizes = iterTileSizes.mapFromIterationSpace(map);
+        propagateIfChanged(results[i], results[i]->join(resultTileSizes));
+      }
+      return success();
+    }
+
+    return success();
+  }
+};
+
+/// Backward analysis: propagates tile sizes from results to operands.
+/// Control flow through scf.for/scf.if is handled automatically by the
+/// framework via RegionBranchOpInterface.
+class TileSizeBackwardAnalysis
+    : public dataflow::SparseBackwardDataFlowAnalysis<TileSizeLattice> {
+public:
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  void setToExitState(TileSizeLattice *lattice) override {
+    // Exit state is uninitialized (identity for meet).
+  }
+
+  LogicalResult
+  visitOperation(Operation *op, ArrayRef<TileSizeLattice *> operands,
+                 ArrayRef<const TileSizeLattice *> results) override {
+    // to_layout is always an anchor op; propagate tile sizes backward to the
+    // input.
+    if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
+      TileSizes tileSizes = getTileSizesFor(toLayout.getResult(), results[0]);
+      TileSizeLattice *inputLattice = operands[0];
+      propagateIfChanged(inputLattice, inputLattice->meet(tileSizes));
+      return success();
+    }
+
+    // Linalg ops: propagate through indexing maps.
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+      unsigned numLoops = linalgOp.getNumLoops();
+      TileSizes iterTileSizes(numLoops);
+      // Gather result tile sizes into iteration space via init maps.
+      for (auto [result, resultLattice] :
+           llvm::zip_equal(linalgOp->getResults(), results)) {
+        TileSizes tileSizes = getTileSizesFor(result, resultLattice);
+        unsigned resultIdx = cast<OpResult>(result).getResultNumber();
+        OpOperand *init = linalgOp.getDpsInitOperand(resultIdx);
+        AffineMap map = linalgOp.getMatchingIndexingMap(init);
+        assert(map.getNumDims() == numLoops);
+        iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
+      }
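+      // Result lattices alone are not enough: for a contraction, the init
+      // maps do not reference the reduction loops, so reduction-dimension
+      // tile sizes can only be recovered from the operand lattices gathered
+      // below (e.g., from to_layout-anchored inputs).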
+      // Gather operand tile sizes into iteration space.
+      for (OpOperand &operand : linalgOp->getOpOperands()) {
+        TileSizes tileSizes = getTileSizesFor(
+            operand.get(), operands[operand.getOperandNumber()]);
+        AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+        assert(map.getNumDims() == numLoops);
+        iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
+      }
+      // Map iteration space tile sizes back to each operand.
+      for (OpOperand &operand : linalgOp->getOpOperands()) {
+        AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+        TileSizes operandTileSizes = iterTileSizes.mapFromIterationSpace(map);
+        TileSizeLattice *operandLattice = operands[operand.getOperandNumber()];
+        propagateIfChanged(operandLattice,
+                           operandLattice->meet(operandTileSizes));
+      }
+      return success();
+    }
+
+    return success();
+  }
+
+  // Required by the base class. Non-forwarded branch operands (e.g., loop
+  // bounds, conditions) are scalars irrelevant to tile size propagation.
+  // Forwarded values (iter_args, yields) are handled by the framework via
+  // RegionBranchOpInterface.
+  void visitBranchOperand(OpOperand &operand) override {}
+  void visitCallOperand(OpOperand &operand) override {}
+  void
+  visitNonControlFlowArguments(RegionSuccessor &successor,
+                               ArrayRef<BlockArgument> arguments) override {}
+};
+
+//===----------------------------------------------------------------------===//
+// Result querying
+//===----------------------------------------------------------------------===//
+
+/// Gather tile sizes into the iteration space of a linalg op by looking up
+/// each operand's lattice state in the solver.
+static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp,
+                                            const DataFlowSolver &solver) {
+  unsigned numLoops = linalgOp.getNumLoops();
+  TileSizes iterTileSizes(numLoops);
+  for (OpOperand &operand : linalgOp->getOpOperands()) {
+    Value val = operand.get();
+    const TileSizeLattice *lattice = solver.lookupState<TileSizeLattice>(val);
+    TileSizes tileSize = getTileSizesFor(val, lattice);
+    if (tileSize.empty()) {
+      continue;
+    }
+    AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+    assert(map.getNumDims() == numLoops);
+    iterTileSizes.merge(tileSize.mapToIterationSpace(map));
+  }
+  return iterTileSizes;
+}
+
+//===----------------------------------------------------------------------===//
+// MaterializeVectorTileSizesPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class MaterializeVectorTileSizesPass final
+    : public impl::MaterializeVectorTileSizesPassBase<
+          MaterializeVectorTileSizesPass> {
+public:
+  void runOnOperation() override {
+    FunctionOpInterface funcOp = getOperation();
+
+    DataFlowSolver solver;
+    dataflow::loadBaselineAnalyses(solver);
+    solver.load<TileSizeForwardAnalysis>();
+    SymbolTableCollection symbolTable;
+    solver.load<TileSizeBackwardAnalysis>(symbolTable);
+
+    if (failed(solver.initializeAndRun(funcOp))) {
+      return signalPassFailure();
+    }
+
+    funcOp->walk([&](linalg::LinalgOp linalgOp) {
+      TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver);
+      if (tileSizes.isOverdefined()) {
+        linalgOp.emitOpError()
+            << "tile size analysis did not determine a valid tile size";
+        return;
+      }
+      if (!tileSizes.isDefined()) {
+        LDBG() << "Analysis did not determine tile size for " << *linalgOp;
+        return;
+      }
+      assert(tileSizes.rank() == linalgOp.getNumLoops());
+
+      linalgOp->setAttr(
+          kVectorTileSizesAttrName,
+          DenseI64ArrayAttr::get(linalgOp->getContext(), tileSizes.getDims()));
+    });
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 64c81e478cc2..60890443b962 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -822,6 +822,16 @@ def MaterializeEncodingIntoNopPass :
   let summary = "Drop the encodings from tensor types with encodings.";
 }
 
+def MaterializeVectorTileSizesPass :
+    InterfacePass<"iree-codegen-materialize-vector-tile-sizes",
+                  "mlir::FunctionOpInterface"> {
+  let summary = "Propagate vector tile sizes and materialize them as attributes";
+  let description = [{
+    Propagate vector tile sizes from to_layout anchors and materialize them as
+    discardable attributes on compute ops.
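+
+    For example, after this pass a linalg op whose tile sizes were fully
+    determined carries an attribute such as
+    `iree_codegen.vector_tile_sizes = array<i64: 64>` (the values depend on
+    the anchoring layouts).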
+  }];
+}
+
 def MaterializeTuningSpecsPass :
     Pass<"iree-codegen-materialize-tuning-specs", "ModuleOp"> {
   let summary = "Load tuning spec transform dialect libraries and encode them in the module";
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 83460c922156..7bdf6f727b3c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -102,6 +102,7 @@ iree_lit_test_suite(
             "materialize_user_config_from_tuning_spec.mlir",
             "materialize_user_configs.mlir",
             "materialize_vector_masking.mlir",
+            "materialize_vector_tile_sizes.mlir",
            "math_transform.mlir",
            "normalize_loop_bounds.mlir",
            "optimize_tensor_insert_extract_slices.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index b84a8121a982..5bb574e1f30e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -97,6 +97,7 @@ iree_lit_test_suite(
     "materialize_user_config_from_tuning_spec.mlir"
     "materialize_user_configs.mlir"
     "materialize_vector_masking.mlir"
+    "materialize_vector_tile_sizes.mlir"
    "math_transform.mlir"
    "normalize_loop_bounds.mlir"
    "optimize_tensor_insert_extract_slices.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir
new file mode 100644
index 000000000000..514546bd6341
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_vector_tile_sizes.mlir
@@ -0,0 +1,282 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-materialize-vector-tile-sizes))' --split-input-file %s | FileCheck %s
+
+// Elementwise chain from to_layout anchor.
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1], batch_tile = [8], outer_tile = [1],
+  thread_tile = [1], element_tile = [8],
+  subgroup_strides = [0], thread_strides = [0]>
+
+// CHECK-LABEL: @elementwise_from_anchor
+func.func @elementwise_from_anchor(%arg0: tensor<63xf16>) -> tensor<63xf16> {
+  %empty = tensor.empty() : tensor<63xf16>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 64>
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%arg0 : tensor<63xf16>) outs(%empty : tensor<63xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %add = arith.addf %in, %in : f16
+    linalg.yield %add : f16
+  } -> tensor<63xf16>
+  %1 = iree_vector_ext.to_layout %0 to layout(#layout) : tensor<63xf16>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 64>
+  %2 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%1 : tensor<63xf16>) outs(%empty : tensor<63xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %mul = arith.mulf %in, %in : f16
+    linalg.yield %mul : f16
+  } -> tensor<63xf16>
+  return %2 : tensor<63xf16>
+}
+
+// -----
+
+// Chain propagation with transpose: tile sizes must propagate through
+// generic A's result to generic B, with B using a transposed indexing map.
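+// With #layout_2d below (undistributed shape [8, 64]), generic A gets tile
+// sizes [8, 64] and the transposing generic B gets [64, 8].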
+
+#layout_2d = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1], batch_tile = [1, 8], outer_tile = [1, 1],
+  thread_tile = [1, 1], element_tile = [8, 8],
+  subgroup_strides = [0, 0], thread_strides = [0, 0]>
+
+// CHECK-LABEL: @chain_propagation_transpose
+func.func @chain_propagation_transpose(
+    %arg0: tensor<8x64xf32>, %arg1: tensor<8x64xf32>) -> tensor<64x8xf32> {
+  %a = iree_vector_ext.to_layout %arg0 to layout(#layout_2d) : tensor<8x64xf32>
+  %empty_ab = tensor.empty() : tensor<8x64xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 8, 64>
+  %ab = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%a, %arg1 : tensor<8x64xf32>, tensor<8x64xf32>)
+    outs(%empty_ab : tensor<8x64xf32>) {
+  ^bb0(%in0: f32, %in1: f32, %out: f32):
+    %add = arith.addf %in0, %in1 : f32
+    linalg.yield %add : f32
+  } -> tensor<8x64xf32>
+  %empty_t = tensor.empty() : tensor<64x8xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 64, 8>
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%ab : tensor<8x64xf32>) outs(%empty_t : tensor<64x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %neg = arith.negf %in : f32
+    linalg.yield %neg : f32
+  } -> tensor<64x8xf32>
+  return %result : tensor<64x8xf32>
+}
+
+// -----
+
+// Chain propagation with dynamic shapes: tile sizes propagate the same way
+// regardless of whether tensor dimensions are static or dynamic.
+
+#layout_dyn = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1], batch_tile = [1, 8], outer_tile = [1, 1],
+  thread_tile = [1, 1], element_tile = [8, 8],
+  subgroup_strides = [0, 0], thread_strides = [0, 0]>
+
+// CHECK-LABEL: @chain_propagation_dynamic
+func.func @chain_propagation_dynamic(
+    %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %a = iree_vector_ext.to_layout %arg0 to layout(#layout_dyn) : tensor<?x?xf32>
+  %empty_ab = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 8, 64>
+  %ab = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%a, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%empty_ab : tensor<?x?xf32>) {
+  ^bb0(%in0: f32, %in1: f32, %out: f32):
+    %add = arith.addf %in0, %in1 : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %empty_c = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 8, 64>
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%ab : tensor<?x?xf32>) outs(%empty_c : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %neg = arith.negf %in : f32
+    linalg.yield %neg : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// scf.for propagation through iter_args.
+// The to_layout inside the loop should propagate tile sizes to the
+// loop iter_args and through the scf.yield.
+// The fill-like %init generic is duplicatable, so it should NOT receive
+// tile sizes.
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [8], batch_tile = [1], outer_tile = [1],
+  thread_tile = [64], element_tile = [1],
+  subgroup_strides = [1], thread_strides = [1]>
+
+// CHECK-LABEL: @scf_for_propagation
+func.func @scf_for_propagation(%arg0: tensor<512xf32>, %lb: index, %ub: index, %step: index) -> tensor<512xf32> {
+  %empty = tensor.empty() : tensor<512xf32>
+  %cst = arith.constant 0.0 : f32
+  // CHECK: linalg.generic
+  // CHECK-NOT: iree_codegen.vector_tile_sizes
+  %init = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%cst : f32) outs(%empty : tensor<512xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<512xf32>
+  %result = scf.for %iv = %lb to %ub step %step iter_args(%iter = %init) -> tensor<512xf32> {
+    %laid_out = iree_vector_ext.to_layout %iter to layout(#layout) : tensor<512xf32>
+    // CHECK: linalg.generic
+    // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 512>
+    %updated = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%laid_out, %arg0 : tensor<512xf32>, tensor<512xf32>) outs(%empty : tensor<512xf32>) {
+    ^bb0(%in0: f32, %in1: f32, %out: f32):
+      %add = arith.addf %in0, %in1 : f32
+      linalg.yield %add : f32
+    } -> tensor<512xf32>
+    scf.yield %updated : tensor<512xf32>
+  }
+  return %result : tensor<512xf32>
+}
+
+// -----
+
+#layout_a = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1], batch_tile = [8], outer_tile = [1],
+  thread_tile = [1], element_tile = [8],
+  subgroup_strides = [0], thread_strides = [0]>
+
+#layout_b = #iree_vector_ext.nested_layout<
+  subgroup_tile = [8, 1], batch_tile = [1, 8], outer_tile = [1, 1],
+  thread_tile = [64, 1], element_tile = [1, 8],
+  subgroup_strides = [1, 0], thread_strides = [1, 0]>
+
+#layout_c = #iree_vector_ext.nested_layout<
+  subgroup_tile = [8], batch_tile = [1], outer_tile = [1],
+  thread_tile = [64], element_tile = [1],
+  subgroup_strides = [1], thread_strides = [1]>
+
+// CHECK-LABEL: @contraction_indexing_maps
+func.func @contraction_indexing_maps(
+    %a: tensor<63xf16>, %b: tensor<512x63xf16>, %c: tensor<512xf32>) -> tensor<512xf32> {
+  %al = iree_vector_ext.to_layout %a to layout(#layout_a) : tensor<63xf16>
+  %bl = iree_vector_ext.to_layout %b to layout(#layout_b) : tensor<512x63xf16>
+  %cl = iree_vector_ext.to_layout %c to layout(#layout_c) : tensor<512xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 64, 512>
+  %result = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1) -> (d0)>,
+      affine_map<(d0, d1) -> (d1, d0)>,
+      affine_map<(d0, d1) -> (d1)>
+    ],
+    iterator_types = ["reduction", "parallel"]
+  } ins(%al, %bl : tensor<63xf16>, tensor<512x63xf16>) outs(%cl : tensor<512xf32>) {
+  ^bb0(%in0: f16, %in1: f16, %out: f32):
+    %ext0 = arith.extf %in0 : f16 to f32
+    %ext1 = arith.extf %in1 : f16 to f32
+    %mul = arith.mulf %ext0, %ext1 : f32
+    %add = arith.addf %mul, %out : f32
+    linalg.yield %add : f32
+  } -> tensor<512xf32>
+  return %result : tensor<512xf32>
+}
+
+// -----
+
+// scf.if propagation: tile size from the to_layout inside one branch
+// should propagate through the scf.if result to consumers outside.
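+// The else branch yields the duplicatable %fill, which contributes no tile
+// sizes and therefore cannot conflict with the layout-anchored branch.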
+
+#layout_if = #iree_vector_ext.nested_layout<
+  subgroup_tile = [8], batch_tile = [1], outer_tile = [1],
+  thread_tile = [64], element_tile = [1],
+  subgroup_strides = [1], thread_strides = [1]>
+
+// CHECK-LABEL: @scf_if_propagation
+func.func @scf_if_propagation(%arg0: tensor<512xf32>, %cond: i1) -> tensor<512xf32> {
+  %empty = tensor.empty() : tensor<512xf32>
+  %cst = arith.constant 0.0 : f32
+  // CHECK-NOT: iree_codegen.vector_tile_sizes
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<512xf32>) -> tensor<512xf32>
+  %if_result = scf.if %cond -> tensor<512xf32> {
+    %laid_out = iree_vector_ext.to_layout %arg0 to layout(#layout_if) : tensor<512xf32>
+    scf.yield %laid_out : tensor<512xf32>
+  } else {
+    scf.yield %fill : tensor<512xf32>
+  }
+  // CHECK: linalg.generic
+  // CHECK-SAME: iree_codegen.vector_tile_sizes = array<i64: 512>
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%if_result : tensor<512xf32>) outs(%empty : tensor<512xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %neg = arith.negf %in : f32
+    linalg.yield %neg : f32
+  } -> tensor<512xf32>
+  return %result : tensor<512xf32>
+}
+
+// -----
+
+// Conflicting tile sizes: two to_layout ops with different tile sizes for the
+// same dimension feed into one generic. The dimension becomes overdefined, so
+// no tile size attribute should be materialized.
+
+#layout_32 = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1], batch_tile = [4], outer_tile = [1],
+  thread_tile = [1], element_tile = [8],
+  subgroup_strides = [0], thread_strides = [0]>
+
+#layout_64 = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1], batch_tile = [8], outer_tile = [1],
+  thread_tile = [1], element_tile = [8],
+  subgroup_strides = [0], thread_strides = [0]>
+
+// CHECK-LABEL: @conflicting_tile_sizes
+func.func @conflicting_tile_sizes(%arg0: tensor<64xf16>, %arg1: tensor<64xf16>) -> tensor<64xf16> {
+  %a = iree_vector_ext.to_layout %arg0 to layout(#layout_32) : tensor<64xf16>
+  %b = iree_vector_ext.to_layout %arg1 to layout(#layout_64) : tensor<64xf16>
+  %empty = tensor.empty() : tensor<64xf16>
+  // CHECK: linalg.generic
+  // CHECK-NOT: iree_codegen.vector_tile_sizes
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%a, %b : tensor<64xf16>, tensor<64xf16>) outs(%empty : tensor<64xf16>) {
+  ^bb0(%in0: f16, %in1: f16, %out: f16):
+    %add = arith.addf %in0, %in1 : f16
+    linalg.yield %add : f16
+  } -> tensor<64xf16>
+  return %result : tensor<64xf16>
+}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
index c0f5a4752da2..90ea99440bfb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
@@ -71,6 +71,8 @@ constexpr StringLiteral kSerializedTuningSpecAttrName =
 constexpr StringLiteral kKernelConfigSpecName = "__kernel_config";
 constexpr StringLiteral kUkernelAttrName = "iree_codegen.ukernel";
 constexpr StringLiteral kUKernelProviderName = "iree_codegen.ukernel_provider";
+constexpr StringLiteral kVectorTileSizesAttrName =
+    "iree_codegen.vector_tile_sizes";
 
 //===----------------------------------------------------------------------===//
 // Helpers for getting/setting iree_codegen.translation_info attribute on a
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index a8a8c7ddfc3c..325914eb2647 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -276,6 +276,9 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager,
   funcPassManager.addPass(IREE::LinalgExt::createDecomposeIm2colPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
+  if (enableMasking) {
+    funcPassManager.addPass(createMaterializeVectorTileSizesPass());
+  }
   // Vectorize.
   GenericVectorizationPassOptions options;
   options.vectorizeCopies = vectorizeCopies;