Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
#include "compiler/plugins/target/LLVMCPU/LibraryBuilder.h"
#include "compiler/plugins/target/LLVMCPU/LinkerTool.h"
#include "compiler/plugins/target/LLVMCPU/StaticLibraryGenerator.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
Expand Down
2 changes: 1 addition & 1 deletion compiler/plugins/target/VMVX/VMVXTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ iree_compiler_cc_library(
"IREECPUDialect.cpp",
],
hdrs = [
"IREECPUAttrs.h",
"IREECPUDialect.h",
"IREECPUEnums.h",
"IREECPUTypes.h",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ iree_cc_library(
NAME
IREECPUDialect
HDRS
"IREECPUAttrs.h"
"IREECPUDialect.h"
"IREECPUEnums.h"
"IREECPUTypes.h"
Expand Down
148 changes: 112 additions & 36 deletions compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUEnums.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
Expand Down Expand Up @@ -287,42 +288,116 @@ SmallVector<bool> LoweringConfigAttr::getVectorScalableFlags() const {
// CPU MMA intrinsic layout (MxNxK shape and element types)
//===----------------------------------------------------------------------===//

namespace {

struct CPUOpaqueMmaLayout {
int64_t mSize = 0;
int64_t nSize = 0;
int64_t kSize = 0;
Type aType;
Type bType;
Type cType;
};

} // namespace

static std::tuple<int64_t, int64_t, int64_t>
getMNKShapeFromIntrinsic(MMAIntrinsic intrinsic) {
// Helper function for getIntrinsicSwizzle.
// Allows succinctly specifying the tiles for the common case of row-major tiles
// i.e. when the TileSwizzle merely encodes the 2D tile shape and no further
// expand/transpose.
// This case is very common for CPU MMA intrinsics due to:
// 1. CPU matmul instructions having simple tile layouts.
// 2. The inner_tiled op RHS being transposed, meaning that the row-major RHS
// tile really is a column-major RHS matmul tile.
//
// If you're trying to add support for a new intrinsic that doesn't have a
// row-major tile layout, don't look at this function, go directly to
// getIntrinsicSwizzle.
static std::optional<std::tuple<int64_t, int64_t, int64_t>>
getRowMajorTilesMNKShape(MMAIntrinsic intrinsic) {
using Tuple = std::tuple<int64_t, int64_t, int64_t>;
switch (intrinsic) {
case MMAIntrinsic::None:
return {0, 0, 0};
return Tuple{0, 0, 0};
case MMAIntrinsic::MMA_X86_AVX512_1x8x1_F64_F64:
return {1, 8, 1};
return Tuple{1, 8, 1};
case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F32:
case MMAIntrinsic::MMA_X86_AVX512_1x16x1_F32_F16_CASTF32:
return {1, 16, 1};
return Tuple{1, 16, 1};
case MMAIntrinsic::MMA_X86_AVX512FP16_1x32x1_F16_F16:
return {1, 32, 1};
return Tuple{1, 32, 1};
case MMAIntrinsic::MMA_X86_AVX512BF16_1x16x2_F32_BF16:
case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I16:
case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I16:
case MMAIntrinsic::MMA_X86_AVX512_1x16x2_I32_I8_CASTI16:
case MMAIntrinsic::MMA_X86_AVX512VNNI_1x16x2_I32_I8_CASTI16:
return {1, 16, 2};
return Tuple{1, 16, 2};
default:
return {0, 0, 0};
return {};
}
}

Codegen::TileSwizzle getIntrinsicSwizzle(IREE::CPU::MMAIntrinsic mma,
int operandIdx) {
using TileSwizzle = Codegen::TileSwizzle;
using Kind = TileSwizzle::Dim::Kind;

auto maybeMnkTuple = getRowMajorTilesMNKShape(mma);
if (!maybeMnkTuple) {
// Whenever one adds support for a new intrinsic that doesn't have a
// row-major tile layout, new logic goes here.
assert(false && "Non-row-major-tile intrinsics not yet implemented.");
return TileSwizzle();
}
auto [mSize, nSize, kSize] = *maybeMnkTuple;
TileSwizzle swizzle;
swizzle.expandShape.resize(2);
auto expandIfNonUnit = [](TileSwizzle &swizzle, int dim, int size) {
if (size > 1) {
Codegen::expand(swizzle, dim, TileSwizzle::Dim{Kind::Internal, size});
}
};

if (operandIdx == 0) {
constexpr int M = 0, K = 1;
expandIfNonUnit(swizzle, K, kSize);
expandIfNonUnit(swizzle, M, mSize);
} else if (operandIdx == 1) {
constexpr int N = 0, K = 1;
expandIfNonUnit(swizzle, K, kSize);
expandIfNonUnit(swizzle, N, nSize);
} else {
constexpr int N = 0, M = 1;
expandIfNonUnit(swizzle, N, nSize);
expandIfNonUnit(swizzle, M, mSize);
}
return swizzle;
}

Codegen::TileSwizzle getSwizzle(IREE::CPU::DataTiledMMAAttr mma,
int operandIdx) {
using TileSwizzle = Codegen::TileSwizzle;
using Kind = TileSwizzle::Dim::Kind;
TileSwizzle swizzle = getIntrinsicSwizzle(mma.getIntrinsic(), operandIdx);
TileSwizzle::Dim intrinsicsM = {Kind::CrossIntrinsic, mma.getIntrinsicsM()};
TileSwizzle::Dim intrinsicsN = {Kind::CrossIntrinsic, mma.getIntrinsicsN()};
TileSwizzle::Dim intrinsicsK = {Kind::CrossIntrinsic, mma.getIntrinsicsK()};
// LHS: (M, K); RHS: (K, N); Acc: (M, N).
if (operandIdx == 0) {
constexpr int M = 0, K = 1;
if (intrinsicsK.size > 1) {
Codegen::expand(swizzle, K, intrinsicsK);
}
if (intrinsicsM.size > 1) {
Codegen::expand(swizzle, M, intrinsicsM);
}
} else if (operandIdx == 1) {
constexpr int N = 0, K = 1;
if (intrinsicsK.size > 1) {
Codegen::expand(swizzle, K, intrinsicsK);
}
if (intrinsicsN.size > 1) {
Codegen::expand(swizzle, N, intrinsicsN);
}
} else {
constexpr int M = 0, N = 1;
if (intrinsicsN.size > 1) {
Codegen::expand(swizzle, N, intrinsicsN);
}
if (intrinsicsM.size > 1) {
Codegen::expand(swizzle, M, intrinsicsM);
}
}
return swizzle;
}

static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
MMAIntrinsic intrinsic) {
Type f64 = Float64Type::get(context);
Expand Down Expand Up @@ -356,14 +431,6 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
}
}

static CPUOpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
MMAIntrinsic intrinsic) {
CPUOpaqueMmaLayout o;
std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
std::tie(o.mSize, o.nSize, o.kSize) = getMNKShapeFromIntrinsic(intrinsic);
return o;
}

//===----------------------------------------------------------------------===//
// DataTiledMMA Attributes
//===----------------------------------------------------------------------===//
Expand All @@ -385,13 +452,22 @@ void DataTiledMMAAttr::getUndistributedTileTypes(
result.clear();
return;
}
CPUOpaqueMmaLayout o = getOpaqueMMALayout(ctx, intrinsic);
int64_t m = o.mSize * getIntrinsicsM();
int64_t n = o.nSize * getIntrinsicsN();
int64_t k = o.kSize * getIntrinsicsK();
result.assign({VectorType::get({m, k}, o.aType),
VectorType::get({k, n}, o.bType),
VectorType::get({m, n}, o.cType)});
auto lhsSwizzle = getSwizzle(*this, 0);
auto rhsSwizzle = getSwizzle(*this, 1);
auto getTileSize = [](const Codegen::TileSwizzle &swizzle, int srcDimIdx) {
int64_t size = 1;
auto e = swizzle.expandShape[srcDimIdx];
for (auto d : e) {
size *= d.size;
}
return size;
};
int64_t m = getTileSize(lhsSwizzle, 0);
int64_t n = getTileSize(rhsSwizzle, 0);
int64_t k = getTileSize(rhsSwizzle, 1);
auto [aType, bType, cType] = getABCElementTypes(ctx, intrinsic);
result.assign({VectorType::get({m, k}, aType), VectorType::get({k, n}, bType),
VectorType::get({m, n}, cType)});
}

void DataTiledMMAAttr::getDistributedTileTypes(
Expand Down
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move the declarations to IREECPUTypes.h?

When I initiated it, I only wanted to avoid numbers of includes. Some time later, there was a discussion, and I'm +1 on what Ben said: https://discord.com/channels/689900678990135345/689900680009482386/1417237795914907659

I usually just have a FooTypes.h that includes everything but the ops - the ops are often the biggest thing I want to avoid polluting scope with
(attrs/interfaces/etc are type-like, at least, in that they are C++ types... is what I tell myself lol)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2026 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_CODEGEN_DIALECT_CPU_IREECPUATTRS_H_
#define IREE_COMPILER_CODEGEN_DIALECT_CPU_IREECPUATTRS_H_

#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUEnums.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"

#define GET_ATTRDEF_CLASSES
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h.inc"

namespace mlir::iree_compiler::IREE::CPU {

// Returns the TileSwizzle for the given intrinsic and operand index.
Codegen::TileSwizzle getIntrinsicSwizzle(MMAIntrinsic mma, int operandIdx);

// Returns the TileSwizzle for the given MMA attr and operand index.
Codegen::TileSwizzle getSwizzle(DataTiledMMAAttr mma, int operandIdx);

} // namespace mlir::iree_compiler::IREE::CPU

#endif // IREE_COMPILER_CODEGEN_DIALECT_CPU_IREECPUATTRS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"

#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.cpp.inc"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"

namespace mlir::iree_compiler::IREE::CPU {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUEnums.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"

namespace mlir::iree_compiler::IREE::CPU {
Expand Down Expand Up @@ -41,8 +40,5 @@ StringRef getTilingLevelName(TilingLevel level);

} // namespace mlir::iree_compiler::IREE::CPU

// clang-format off
#define GET_ATTRDEF_CLASSES
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.h.inc"
// clang-format on
#endif // IREE_COMPILER_CODEGEN_DIALECT_CPU_IREECPUTYPES_H_
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/DialectImplementation.h"

Expand Down Expand Up @@ -44,4 +43,37 @@ TileSwizzle::verify(function_ref<InFlightDiagnostic()> emitError) const {
return success();
}

// Returns the index of the first destination dimension corresponding to the
// given source dimension `srcIdx`.
static size_t expandedDimIdx(const TileSwizzle::ExpandShapeType &expandShape,
size_t srcIdx) {
size_t dstIdx = 0;
for (size_t i = 0; i < srcIdx; ++i) {
dstIdx += expandShape[i].size();
}
return dstIdx;
}

void expand(TileSwizzle &swizzle, size_t srcIdx, TileSwizzle::Dim dim) {
int64_t dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx);
swizzle.expandShape[srcIdx].insert(swizzle.expandShape[srcIdx].begin(), dim);
for (int64_t &p : swizzle.permutation) {
p += (p >= dstIdx);
}
swizzle.permutation.insert(swizzle.permutation.begin(), dstIdx);
}

SmallVector<int64_t>
sliceSwizzledShape(const TileSwizzle &swizzle,
llvm::function_ref<bool(TileSwizzle::Dim)> predicate) {
SmallVector<int64_t> shape;
for (TileSwizzle::ExpandShapeDimVectorType e : swizzle.expandShape) {
for (TileSwizzle::Dim d : e) {
shape.push_back(predicate(d) ? d.size : 1);
}
}
applyPermutationToVector(shape, swizzle.permutation);
return shape;
}

} // namespace mlir::iree_compiler::IREE::Codegen
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ struct TileSwizzle {
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError) const;
};

/// Returns the swizzled tile shape, but with dim sizes overwritten with 1 if
/// `predicate` returns false.
SmallVector<int64_t>
sliceSwizzledShape(const TileSwizzle &swizzle,
llvm::function_ref<bool(TileSwizzle::Dim)> predicate);

/// Pushes `dim` to the front of `swizzle.expandShape[srcIdx]`, and updates
/// `swizzle.permutation` accordingly.
void expand(TileSwizzle &swizzle, size_t srcIdx, TileSwizzle::Dim dim);

using ScalableTileFlags = SmallVector<bool>;
/// Container of information needed to materialize the layout transformations.
struct MaterializeEncodingInfo {
Expand Down
Loading
Loading