77#include " iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
88#include " iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
99#include " iree/compiler/Codegen/Dialect/VectorExt/Transforms/Transforms.h"
10+ #include " iree/compiler/Codegen/Interfaces/VectorizableOpInterface.h"
1011#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1112#include " iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1213#include " mlir/Dialect/Linalg/IR/Linalg.h"
1314#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
14- #include " mlir/Dialect/UB/IR/UBOps.h"
1515#include " mlir/Dialect/Vector/IR/VectorOps.h"
1616#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
1717#include " mlir/IR/Builders.h"
@@ -28,88 +28,28 @@ struct VectorizeToLayoutOpPattern final
2828 : OpRewritePattern<IREE::VectorExt::ToLayoutOp> {
2929 using Base::Base;
3030
31- vector::TransferReadOp
32- createReadOp (ImplicitLocOpBuilder &builder,
33- IREE::VectorExt::ToLayoutOp toLayoutOp) const {
34- ShapedType inputTy = toLayoutOp.getType ();
35- auto zero = arith::ConstantIndexOp::create (builder, 0 );
36- auto identityMap = builder.getMultiDimIdentityMap (inputTy.getRank ());
37- SmallVector<int64_t > readShape =
38- toLayoutOp.getLayout ().getUndistributedShape ();
39- Value mask = nullptr ;
40- bool needsMask = !toLayoutOp.getType ().hasStaticShape () ||
41- (readShape != inputTy.getShape ());
42- if (needsMask) {
43- SmallVector<OpFoldResult> mixedSourceDims = tensor::getMixedSizes (
44- builder, builder.getLoc (), toLayoutOp.getInput ());
45- auto maskType = VectorType::get (readShape, builder.getI1Type ());
46- mask = vector::CreateMaskOp::create (builder, maskType, mixedSourceDims);
47- }
48- VectorType vectorType =
49- VectorType::get (readShape, inputTy.getElementType ());
50- auto inBounds =
51- builder.getBoolArrayAttr (SmallVector<bool >(vectorType.getRank (), true ));
52- auto padValue = ub::PoisonOp::create (builder, inputTy.getElementType ());
53- auto read = vector::TransferReadOp::create (
54- builder,
55- /* type=*/ vectorType,
56- /* source=*/ toLayoutOp.getInput (),
57- /* indices=*/ ValueRange{SmallVector<Value>(readShape.size (), zero)},
58- /* permutation_map=*/ identityMap,
59- /* padding=*/ padValue,
60- /* mask=*/ mask,
61- /* in_bounds=*/ inBounds);
62- return read;
63- }
64-
65- vector::TransferWriteOp
66- createWriteOp (ImplicitLocOpBuilder &builder,
67- IREE::VectorExt::ToLayoutOp tensorLayoutOp,
68- Value vectorLayoutOp, Value mask) const {
69- ShapedType tensorTy = tensorLayoutOp.getType ();
70- auto resType =
71- RankedTensorType::get (tensorTy.getShape (), tensorTy.getElementType ());
72- auto zero = arith::ConstantIndexOp::create (builder, 0 );
73- int64_t rank = tensorTy.getShape ().size ();
74- auto inBounds = builder.getBoolArrayAttr (SmallVector<bool >(rank, true ));
75- auto identityMap = builder.getMultiDimIdentityMap (tensorTy.getRank ());
76- return vector::TransferWriteOp::create (
77- builder,
78- /* result=*/ resType,
79- /* vector=*/ vectorLayoutOp,
80- /* source=*/ tensorLayoutOp.getInput (),
81- /* indices=*/ ValueRange{SmallVector<Value>(rank, zero)},
82- /* permutation_map=*/ identityMap,
83- /* mask=*/ mask,
84- /* inBounds=*/ inBounds);
85- }
86-
8731 LogicalResult matchAndRewrite (IREE::VectorExt::ToLayoutOp toLayoutOp,
8832 PatternRewriter &rewriter) const override {
89- if (!toLayoutOp.hasTensorSemantics ()) {
33+ auto vectorizableOp =
34+ cast<VectorizableOpInterface>(toLayoutOp.getOperation ());
35+ SmallVector<int64_t > vectorSizes;
36+ SmallVector<bool > scalableDims;
37+ if (!vectorizableOp.isVectorizable (vectorSizes, scalableDims)) {
9038 return failure ();
9139 }
92- OpBuilder::InsertionGuard g (rewriter);
93- rewriter.setInsertionPoint (toLayoutOp);
94- Location loc = toLayoutOp.getLoc ();
95- ImplicitLocOpBuilder builder{loc, rewriter};
96- vector::TransferReadOp readOp = createReadOp (builder, toLayoutOp);
97- // Create the toLayout operation but with vector types instead.
98- auto newLayoutOp = IREE::VectorExt::ToLayoutOp::create (
99- builder, readOp, toLayoutOp.getLayout (),
100- toLayoutOp.getSharedMemoryConversion ());
101- // Create the write back to a tensor.
102- vector::TransferWriteOp writeOp =
103- createWriteOp (builder, toLayoutOp, newLayoutOp, readOp.getMask ());
104- rewriter.replaceOp (toLayoutOp, writeOp);
40+ FailureOr<SmallVector<Value>> result =
41+ vectorizableOp.vectorize (rewriter, vectorSizes, scalableDims);
42+ if (failed (result)) {
43+ return failure ();
44+ }
45+ rewriter.replaceOp (toLayoutOp, *result);
10546 return success ();
10647 }
10748};
10849
10950struct VectorizeIREEVectorExtOpsPass final
11051 : impl::VectorizeIREEVectorExtOpsPassBase<VectorizeIREEVectorExtOpsPass> {
11152 void runOnOperation () override {
112-
11353 MLIRContext *ctx = &getContext ();
11454 RewritePatternSet patterns (ctx);
11555 patterns.add <VectorizeToLayoutOpPattern>(ctx);
0 commit comments