Skip to content

Commit c7dc7da

Browse files
committed
[Codegen] Migrate ToLayoutOp to VectorizableOpInterface
No new tests because it is an NFC in terms of functionality. It just follows different mechanism for vectorization. It is a step towards https://lists.lfaidata.foundation/g/iree-technical-discussion/message/15 Assisted-by: Claude Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent a4dc64b commit c7dc7da

5 files changed

Lines changed: 94 additions & 72 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ iree_compiler_cc_library(
5252
deps = [
5353
":PassesIncGen",
5454
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
55+
"//compiler/src/iree/compiler/Codegen/Interfaces:VectorizableOpInterface",
5556
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
5657
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
5758
"@llvm-project//llvm:Support",

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ iree_cc_library(
5454
MLIRVectorDialect
5555
MLIRVectorUtils
5656
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
57+
iree::compiler::Codegen::Interfaces::VectorizableOpInterface
5758
iree::compiler::Dialect::LinalgExt::IR
5859
iree::compiler::Dialect::LinalgExt::Utils
5960
PUBLIC

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#ifndef IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_TRANSFORMS_H_
88
#define IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_TRANSFORMS_H_
99

10+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
1011
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1112
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1213
#include "mlir/IR/Builders.h"

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp

Lines changed: 12 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
88
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
99
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Transforms.h"
10+
#include "iree/compiler/Codegen/Interfaces/VectorizableOpInterface.h"
1011
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1112
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1213
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1314
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14-
#include "mlir/Dialect/UB/IR/UBOps.h"
1515
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1616
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1717
#include "mlir/IR/Builders.h"
@@ -28,88 +28,28 @@ struct VectorizeToLayoutOpPattern final
2828
: OpRewritePattern<IREE::VectorExt::ToLayoutOp> {
2929
using Base::Base;
3030

31-
vector::TransferReadOp
32-
createReadOp(ImplicitLocOpBuilder &builder,
33-
IREE::VectorExt::ToLayoutOp toLayoutOp) const {
34-
ShapedType inputTy = toLayoutOp.getType();
35-
auto zero = arith::ConstantIndexOp::create(builder, 0);
36-
auto identityMap = builder.getMultiDimIdentityMap(inputTy.getRank());
37-
SmallVector<int64_t> readShape =
38-
toLayoutOp.getLayout().getUndistributedShape();
39-
Value mask = nullptr;
40-
bool needsMask = !toLayoutOp.getType().hasStaticShape() ||
41-
(readShape != inputTy.getShape());
42-
if (needsMask) {
43-
SmallVector<OpFoldResult> mixedSourceDims = tensor::getMixedSizes(
44-
builder, builder.getLoc(), toLayoutOp.getInput());
45-
auto maskType = VectorType::get(readShape, builder.getI1Type());
46-
mask = vector::CreateMaskOp::create(builder, maskType, mixedSourceDims);
47-
}
48-
VectorType vectorType =
49-
VectorType::get(readShape, inputTy.getElementType());
50-
auto inBounds =
51-
builder.getBoolArrayAttr(SmallVector<bool>(vectorType.getRank(), true));
52-
auto padValue = ub::PoisonOp::create(builder, inputTy.getElementType());
53-
auto read = vector::TransferReadOp::create(
54-
builder,
55-
/*type=*/vectorType,
56-
/*source=*/toLayoutOp.getInput(),
57-
/*indices=*/ValueRange{SmallVector<Value>(readShape.size(), zero)},
58-
/*permutation_map=*/identityMap,
59-
/*padding=*/padValue,
60-
/*mask=*/mask,
61-
/*in_bounds=*/inBounds);
62-
return read;
63-
}
64-
65-
vector::TransferWriteOp
66-
createWriteOp(ImplicitLocOpBuilder &builder,
67-
IREE::VectorExt::ToLayoutOp tensorLayoutOp,
68-
Value vectorLayoutOp, Value mask) const {
69-
ShapedType tensorTy = tensorLayoutOp.getType();
70-
auto resType =
71-
RankedTensorType::get(tensorTy.getShape(), tensorTy.getElementType());
72-
auto zero = arith::ConstantIndexOp::create(builder, 0);
73-
int64_t rank = tensorTy.getShape().size();
74-
auto inBounds = builder.getBoolArrayAttr(SmallVector<bool>(rank, true));
75-
auto identityMap = builder.getMultiDimIdentityMap(tensorTy.getRank());
76-
return vector::TransferWriteOp::create(
77-
builder,
78-
/*result=*/resType,
79-
/*vector=*/vectorLayoutOp,
80-
/*source=*/tensorLayoutOp.getInput(),
81-
/*indices=*/ValueRange{SmallVector<Value>(rank, zero)},
82-
/*permutation_map=*/identityMap,
83-
/*mask=*/mask,
84-
/*inBounds=*/inBounds);
85-
}
86-
8731
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
8832
PatternRewriter &rewriter) const override {
89-
if (!toLayoutOp.hasTensorSemantics()) {
33+
auto vectorizableOp =
34+
cast<VectorizableOpInterface>(toLayoutOp.getOperation());
35+
SmallVector<int64_t> vectorSizes;
36+
SmallVector<bool> scalableDims;
37+
if (!vectorizableOp.isVectorizable(vectorSizes, scalableDims)) {
9038
return failure();
9139
}
92-
OpBuilder::InsertionGuard g(rewriter);
93-
rewriter.setInsertionPoint(toLayoutOp);
94-
Location loc = toLayoutOp.getLoc();
95-
ImplicitLocOpBuilder builder{loc, rewriter};
96-
vector::TransferReadOp readOp = createReadOp(builder, toLayoutOp);
97-
// Create the toLayout operation but with vector types instead.
98-
auto newLayoutOp = IREE::VectorExt::ToLayoutOp::create(
99-
builder, readOp, toLayoutOp.getLayout(),
100-
toLayoutOp.getSharedMemoryConversion());
101-
// Create the write back to a tensor.
102-
vector::TransferWriteOp writeOp =
103-
createWriteOp(builder, toLayoutOp, newLayoutOp, readOp.getMask());
104-
rewriter.replaceOp(toLayoutOp, writeOp);
40+
FailureOr<SmallVector<Value>> result =
41+
vectorizableOp.vectorize(rewriter, vectorSizes, scalableDims);
42+
if (failed(result)) {
43+
return failure();
44+
}
45+
rewriter.replaceOp(toLayoutOp, *result);
10546
return success();
10647
}
10748
};
10849

10950
struct VectorizeIREEVectorExtOpsPass final
11051
: impl::VectorizeIREEVectorExtOpsPassBase<VectorizeIREEVectorExtOpsPass> {
11152
void runOnOperation() override {
112-
11353
MLIRContext *ctx = &getContext();
11454
RewritePatternSet patterns(ctx);
11555
patterns.add<VectorizeToLayoutOpPattern>(ctx);

compiler/src/iree/compiler/Codegen/Interfaces/VectorizableOpInterface.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,80 @@ struct ArgCompareOpVectorizationModel
310310
}
311311
};
312312

313+
struct ToLayoutOpVectorizationModel
314+
: public VectorizableOpInterface::ExternalModel<
315+
ToLayoutOpVectorizationModel, IREE::VectorExt::ToLayoutOp> {
316+
317+
bool isVectorizable(Operation *op, ArrayRef<int64_t> vectorSizes,
318+
ArrayRef<bool> scalableDims,
319+
DictionaryAttr options) const {
320+
auto toLayoutOp = cast<IREE::VectorExt::ToLayoutOp>(op);
321+
return toLayoutOp.hasTensorSemantics();
322+
}
323+
324+
FailureOr<SmallVector<Value>> vectorize(Operation *op, RewriterBase &rewriter,
325+
ArrayRef<int64_t> vectorSizes,
326+
ArrayRef<bool> scalableDims,
327+
DictionaryAttr options) const {
328+
auto toLayoutOp = cast<IREE::VectorExt::ToLayoutOp>(op);
329+
if (!toLayoutOp.hasTensorSemantics()) {
330+
return failure();
331+
}
332+
OpBuilder::InsertionGuard g(rewriter);
333+
rewriter.setInsertionPoint(toLayoutOp);
334+
Location loc = toLayoutOp.getLoc();
335+
ShapedType inputTy = toLayoutOp.getType();
336+
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
337+
auto identityMap = rewriter.getMultiDimIdentityMap(inputTy.getRank());
338+
SmallVector<int64_t> readShape =
339+
toLayoutOp.getLayout().getUndistributedShape();
340+
Value mask = nullptr;
341+
if (!toLayoutOp.getType().hasStaticShape()) {
342+
SmallVector<OpFoldResult> mixedSourceDims =
343+
tensor::getMixedSizes(rewriter, loc, toLayoutOp.getInput());
344+
auto maskType = VectorType::get(readShape, rewriter.getI1Type());
345+
mask = vector::CreateMaskOp::create(rewriter, loc, maskType,
346+
mixedSourceDims);
347+
}
348+
VectorType vectorType =
349+
VectorType::get(readShape, inputTy.getElementType());
350+
auto inBounds = rewriter.getBoolArrayAttr(
351+
SmallVector<bool>(vectorType.getRank(), true));
352+
auto padValue =
353+
ub::PoisonOp::create(rewriter, loc, inputTy.getElementType());
354+
auto readOp = vector::TransferReadOp::create(
355+
rewriter, loc,
356+
/*type=*/vectorType,
357+
/*source=*/toLayoutOp.getInput(),
358+
/*indices=*/ValueRange{SmallVector<Value>(readShape.size(), zero)},
359+
/*permutation_map=*/identityMap,
360+
/*padding=*/padValue,
361+
/*mask=*/mask,
362+
/*in_bounds=*/inBounds);
363+
// Create the toLayout operation but with vector types instead.
364+
auto newLayoutOp = IREE::VectorExt::ToLayoutOp::create(
365+
rewriter, loc, readOp, toLayoutOp.getLayout(),
366+
toLayoutOp.getSharedMemoryConversion());
367+
// Create the write back to a tensor.
368+
ShapedType tensorTy = toLayoutOp.getType();
369+
auto resType =
370+
RankedTensorType::get(tensorTy.getShape(), tensorTy.getElementType());
371+
int64_t rank = tensorTy.getShape().size();
372+
auto writeInBounds =
373+
rewriter.getBoolArrayAttr(SmallVector<bool>(rank, true));
374+
auto writeIdentityMap = rewriter.getMultiDimIdentityMap(tensorTy.getRank());
375+
auto writeOp = vector::TransferWriteOp::create(
376+
rewriter, loc,
377+
/*result=*/resType,
378+
/*vector=*/newLayoutOp,
379+
/*source=*/toLayoutOp.getInput(),
380+
/*indices=*/ValueRange{SmallVector<Value>(rank, zero)},
381+
/*permutation_map=*/writeIdentityMap,
382+
/*mask=*/mask,
383+
/*inBounds=*/writeInBounds);
384+
return SmallVector<Value>{writeOp.getResult()};
385+
}
386+
};
313387
} // namespace
314388

315389
void registerVectorizableOpInterfaceExternalModels(DialectRegistry &registry) {
@@ -320,6 +394,11 @@ void registerVectorizableOpInterfaceExternalModels(DialectRegistry &registry) {
320394
IREE::LinalgExt::ArgCompareOp::attachInterface<
321395
ArgCompareOpVectorizationModel>(*ctx);
322396
});
397+
registry.addExtension(+[](MLIRContext *ctx,
398+
IREE::VectorExt::IREEVectorExtDialect *dialect) {
399+
IREE::VectorExt::ToLayoutOp::attachInterface<ToLayoutOpVectorizationModel>(
400+
*ctx);
401+
});
323402
}
324403

325404
} // namespace mlir::iree_compiler

0 commit comments

Comments
 (0)