Skip to content

Commit 4c1bb13

Browse files
committed
[Codegen] Migrate MapStoreOp to VectorizableOpInterface
The revision also deletes VectorizeIREELinalgExtOps, because it is already covered in GenericVectorization pass. No new tests because it is an NFC in terms of functionality. It just follows different mechanism for vectorization. It is a step towards https://lists.lfaidata.foundation/g/iree-technical-discussion/message/15 Assisted-by: Claude Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent a23ccb1 commit 4c1bb13

12 files changed

Lines changed: 227 additions & 264 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization.mlir

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,3 +1001,145 @@ func.func @arg_compare_with_index_base(%input: tensor<4x128xf32>,
10011001
// CHECK: %[[WRITE_VAL:.+]] = vector.transfer_write %[[RESULT_VAL]], %[[OUT_VAL]]
10021002
// CHECK: %[[WRITE_IDX:.+]] = vector.transfer_write %[[RESULT_IDX]], %[[OUT_IDX]]
10031003
// CHECK: return %[[WRITE_VAL]], %[[WRITE_IDX]]
1004+
1005+
// -----
1006+
1007+
func.func @map_store(
1008+
%input: tensor<4x16x64xf32>, %output: tensor<4x16x64xf32>
1009+
) -> tensor<4x16x64xf32> {
1010+
%0 = iree_linalg_ext.map_store %input into %output {
1011+
^bb0(%idx0: index, %idx1: index, %idx2: index):
1012+
%mask = arith.constant true
1013+
iree_linalg_ext.yield %idx0, %idx1, %idx2, %mask : index, index, index, i1
1014+
} : tensor<4x16x64xf32> into tensor<4x16x64xf32> -> tensor<4x16x64xf32>
1015+
return %0 : tensor<4x16x64xf32>
1016+
}
1017+
// CHECK-LABEL: @map_store
1018+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
1019+
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
1020+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[INPUT]]
1021+
// CHECK: %[[MAP_SCATTER:.+]] = iree_linalg_ext.map_store
1022+
// CHECK-SAME: %[[READ]] into %[[OUTPUT]]
1023+
// CHECK: : vector<4x16x64xf32> into tensor<4x16x64xf32> -> tensor<4x16x64xf32>
1024+
// CHECK: return %[[MAP_SCATTER]] : tensor<4x16x64xf32>
1025+
1026+
// -----
1027+
1028+
func.func @no_vectorize_map_store_dynamic(
1029+
%input: tensor<?xf32>, %output: tensor<64xf32>
1030+
) -> tensor<64xf32> {
1031+
%0 = iree_linalg_ext.map_store %input into %output {
1032+
^bb0(%idx0: index):
1033+
%mask = arith.constant true
1034+
iree_linalg_ext.yield %idx0, %mask : index, i1
1035+
} : tensor<?xf32> into tensor<64xf32> -> tensor<64xf32>
1036+
return %0 : tensor<64xf32>
1037+
}
1038+
// CHECK-LABEL: @no_vectorize_map_store_dynamic
1039+
// CHECK-NOT: vector
1040+
1041+
// -----
1042+
1043+
func.func @map_store_f4_multiple_of_byte(
1044+
%input: tensor<2x2xf4E2M1FN>, %output: tensor<2x2xf4E2M1FN>
1045+
) -> tensor<2x2xf4E2M1FN> {
1046+
%0 = iree_linalg_ext.map_store %input into %output {
1047+
^bb0(%idx0: index, %idx1: index):
1048+
%mask = arith.constant true
1049+
iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1
1050+
} : tensor<2x2xf4E2M1FN> into tensor<2x2xf4E2M1FN> -> tensor<2x2xf4E2M1FN>
1051+
return %0 : tensor<2x2xf4E2M1FN>
1052+
}
1053+
// CHECK-LABEL: @map_store_f4_multiple_of_byte
1054+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
1055+
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
1056+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[INPUT]]
1057+
// CHECK: %[[MAP_SCATTER:.+]] = iree_linalg_ext.map_store
1058+
// CHECK-SAME: %[[READ]] into %[[OUTPUT]]
1059+
// CHECK: : vector<2x2xf4E2M1FN> into tensor<2x2xf4E2M1FN> -> tensor<2x2xf4E2M1FN>
1060+
// CHECK: return %[[MAP_SCATTER]] : tensor<2x2xf4E2M1FN>
1061+
1062+
// -----
1063+
1064+
func.func @map_store_f4_not_multiple_of_byte(
1065+
%input: tensor<2x1xf4E2M1FN>, %output: tensor<2x2xf4E2M1FN>
1066+
) -> tensor<2x2xf4E2M1FN> {
1067+
%0 = iree_linalg_ext.map_store %input into %output {
1068+
^bb0(%idx0: index, %idx1: index):
1069+
%mask = arith.constant true
1070+
iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1
1071+
} : tensor<2x1xf4E2M1FN> into tensor<2x2xf4E2M1FN> -> tensor<2x2xf4E2M1FN>
1072+
return %0 : tensor<2x2xf4E2M1FN>
1073+
}
1074+
// CHECK-LABEL: @map_store_f4_not_multiple_of_byte
1075+
// CHECK-NOT: vector
1076+
1077+
// -----
1078+
1079+
func.func @map_store_f4_unit_stride(
1080+
%input: tensor<2x2xf4E2M1FN>, %output: tensor<2x4xf4E2M1FN>
1081+
) -> tensor<2x4xf4E2M1FN> {
1082+
%0 = iree_linalg_ext.map_store %input into %output {
1083+
^bb0(%idx0: index, %idx1: index):
1084+
%mask = arith.constant true
1085+
%1 = affine.apply affine_map<(d0) -> (d0 + 2)>(%idx1)
1086+
iree_linalg_ext.yield %idx0, %1, %mask : index, index, i1
1087+
} : tensor<2x2xf4E2M1FN> into tensor<2x4xf4E2M1FN> -> tensor<2x4xf4E2M1FN>
1088+
return %0 : tensor<2x4xf4E2M1FN>
1089+
}
1090+
// CHECK-LABEL: @map_store_f4_unit_stride
1091+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]
1092+
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9_]+]]
1093+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[INPUT]]
1094+
// CHECK: %[[MAP_SCATTER:.+]] = iree_linalg_ext.map_store
1095+
// CHECK-SAME: %[[READ]] into %[[OUTPUT]]
1096+
// CHECK: : vector<2x2xf4E2M1FN> into tensor<2x4xf4E2M1FN> -> tensor<2x4xf4E2M1FN>
1097+
// CHECK: return %[[MAP_SCATTER]] : tensor<2x4xf4E2M1FN>
1098+
1099+
// -----
1100+
1101+
func.func @map_store_f4_not_unit_stride(
1102+
%input: tensor<2x2xf4E2M1FN>, %output: tensor<2x4xf4E2M1FN>
1103+
) -> tensor<2x4xf4E2M1FN> {
1104+
%0 = iree_linalg_ext.map_store %input into %output {
1105+
^bb0(%idx0: index, %idx1: index):
1106+
%mask = arith.constant true
1107+
%1 = affine.apply affine_map<(d0) -> (d0 * 2)>(%idx1)
1108+
iree_linalg_ext.yield %idx0, %1, %mask : index, index, i1
1109+
} : tensor<2x2xf4E2M1FN> into tensor<2x4xf4E2M1FN> -> tensor<2x4xf4E2M1FN>
1110+
return %0 : tensor<2x4xf4E2M1FN>
1111+
}
1112+
// CHECK-LABEL: @map_store_f4_not_unit_stride
1113+
// CHECK-NOT: vector
1114+
1115+
// -----
1116+
1117+
func.func @map_store_f4_not_index_applied_multiple_times(
1118+
%input: tensor<2x2xf4E2M1FN>, %output: tensor<2x4xf4E2M1FN>
1119+
) -> tensor<2x4xf4E2M1FN> {
1120+
%0 = iree_linalg_ext.map_store %input into %output {
1121+
^bb0(%idx0: index, %idx1: index):
1122+
%mask = arith.constant true
1123+
%1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%idx1, %idx1)
1124+
iree_linalg_ext.yield %idx0, %1, %mask : index, index, i1
1125+
} : tensor<2x2xf4E2M1FN> into tensor<2x4xf4E2M1FN> -> tensor<2x4xf4E2M1FN>
1126+
return %0 : tensor<2x4xf4E2M1FN>
1127+
}
1128+
// CHECK-LABEL: @map_store_f4_not_index_applied_multiple_times
1129+
// CHECK-NOT: vector
1130+
1131+
// -----
1132+
1133+
func.func @map_store_f4_mask_depends_on_inner_index(
1134+
%input: tensor<2x2xf4E2M1FN>, %output: tensor<2x4xf4E2M1FN>
1135+
) -> tensor<2x4xf4E2M1FN> {
1136+
%0 = iree_linalg_ext.map_store %input into %output {
1137+
^bb0(%idx0: index, %idx1: index):
1138+
%c1 = arith.constant 1 : index
1139+
%mask = arith.cmpi uge, %idx1, %c1 : index
1140+
iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1
1141+
} : tensor<2x2xf4E2M1FN> into tensor<2x4xf4E2M1FN> -> tensor<2x4xf4E2M1FN>
1142+
return %0 : tensor<2x4xf4E2M1FN>
1143+
}
1144+
// CHECK-LABEL: @map_store_f4_mask_depends_on_inner_index
1145+
// CHECK-NOT: vector

compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ iree_compiler_cc_library(
233233
":VectorizableOpInterfaceGen",
234234
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
235235
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
236+
"//compiler/src/iree/compiler/Utils",
237+
"@llvm-project//mlir:Analysis",
236238
"@llvm-project//mlir:ArithDialect",
237239
"@llvm-project//mlir:IR",
238240
"@llvm-project//mlir:TensorDialect",

compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,15 @@ iree_cc_library(
169169
"VectorizableOpInterface.cpp"
170170
DEPS
171171
::VectorizableOpInterfaceGen
172+
MLIRAnalysis
172173
MLIRArithDialect
173174
MLIRIR
174175
MLIRTensorDialect
175176
MLIRUBDialect
176177
MLIRVectorDialect
177178
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
178179
iree::compiler::Dialect::LinalgExt::IR
180+
iree::compiler::Utils
179181
PUBLIC
180182
)
181183

compiler/src/iree/compiler/Codegen/Interfaces/VectorizableOpInterface.cpp

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88

99
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
1010
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
11+
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Transforms.h"
1112
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
1213
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
14+
#include "iree/compiler/Utils/Indexing.h"
15+
#include "mlir/Analysis/SliceAnalysis.h"
1316
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
18+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1419
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1520
#include "mlir/Dialect/UB/IR/UBOps.h"
1621
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -385,16 +390,85 @@ struct ToLayoutOpVectorizationModel
385390
}
386391
};
387392

393+
struct MapStoreOpVectorizationModel
394+
: public VectorizableOpInterface::ExternalModel<
395+
MapStoreOpVectorizationModel, IREE::LinalgExt::MapStoreOp> {
396+
397+
bool isVectorizable(Operation *op, ArrayRef<int64_t> vectorSizes,
398+
ArrayRef<bool> scalableDims,
399+
DictionaryAttr options) const {
400+
auto mapStoreOp = cast<IREE::LinalgExt::MapStoreOp>(op);
401+
if (mapStoreOp.isVectorized()) {
402+
return false;
403+
}
404+
ShapedType inputType = mapStoreOp.getInputType();
405+
if (!inputType.hasStaticShape()) {
406+
return false;
407+
}
408+
const int64_t innerSize = inputType.getShape()[inputType.getRank() - 1];
409+
const int64_t bitWidth = inputType.getElementTypeBitWidth();
410+
if ((innerSize * bitWidth % 8) != 0) {
411+
return false;
412+
}
413+
// In case of a sub-byte bitwidth, we check that there is a contiguous copy
414+
// on the inner dimension that is a multiple of a byte. Note that the mask
415+
// shouldn't depend on the inner index for this.
416+
if (bitWidth < 8) {
417+
// First check that the mask is not the forward slice of the inner index.
418+
Value innermostInputIdx =
419+
mapStoreOp.getInputIndex(mapStoreOp.getInputRank() - 1);
420+
SetVector<Operation *> slice;
421+
getForwardSlice(innermostInputIdx, &slice);
422+
Operation *maskOp = mapStoreOp.getMask().getDefiningOp();
423+
if (maskOp && slice.contains(maskOp)) {
424+
return false;
425+
}
426+
// Next check that the inner index of the yield is a unit function of
427+
// the inner input index.
428+
Value innermostOutputIdx =
429+
mapStoreOp.getOutputIndex(mapStoreOp.getOutputRank() - 1);
430+
if (!isUnitFunctionOf(innermostOutputIdx, innermostInputIdx)) {
431+
return false;
432+
}
433+
}
434+
return true;
435+
}
436+
437+
FailureOr<SmallVector<Value>> vectorize(Operation *op, RewriterBase &rewriter,
438+
ArrayRef<int64_t> vectorSizes,
439+
ArrayRef<bool> scalableDims,
440+
DictionaryAttr options) const {
441+
auto mapStoreOp = cast<IREE::LinalgExt::MapStoreOp>(op);
442+
Location loc = mapStoreOp.getLoc();
443+
rewriter.setInsertionPoint(mapStoreOp);
444+
ShapedType inputType = mapStoreOp.getInputType();
445+
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
446+
SmallVector<Value> zeros(inputType.getRank(), zero);
447+
auto inputVectorType =
448+
VectorType::get(inputType.getShape(), inputType.getElementType());
449+
Value inputVector = vector::TransferReadOp::create(
450+
rewriter, loc, inputVectorType, mapStoreOp.getInput(),
451+
/*indices=*/zeros,
452+
/*padding=*/std::nullopt);
453+
auto vectorizedMapStoreOp =
454+
clone(rewriter, mapStoreOp, mapStoreOp.getResultTypes(),
455+
{inputVector, mapStoreOp.getOutput()});
456+
return SmallVector<Value>(vectorizedMapStoreOp->getResults());
457+
}
458+
};
459+
388460
} // namespace
389461

390462
void registerVectorizableOpInterfaceExternalModels(DialectRegistry &registry) {
391-
registry.addExtension(
392-
+[](MLIRContext *ctx, IREE::LinalgExt::IREELinalgExtDialect *dialect) {
393-
IREE::LinalgExt::GatherOp::attachInterface<GatherOpVectorizationModel>(
394-
*ctx);
395-
IREE::LinalgExt::ArgCompareOp::attachInterface<
396-
ArgCompareOpVectorizationModel>(*ctx);
397-
});
463+
registry.addExtension(+[](MLIRContext *ctx,
464+
IREE::LinalgExt::IREELinalgExtDialect *dialect) {
465+
IREE::LinalgExt::GatherOp::attachInterface<GatherOpVectorizationModel>(
466+
*ctx);
467+
IREE::LinalgExt::ArgCompareOp::attachInterface<
468+
ArgCompareOpVectorizationModel>(*ctx);
469+
IREE::LinalgExt::MapStoreOp::attachInterface<MapStoreOpVectorizationModel>(
470+
*ctx);
471+
});
398472
registry.addExtension(+[](MLIRContext *ctx,
399473
IREE::VectorExt::IREEVectorExtDialect *dialect) {
400474
IREE::VectorExt::ToLayoutOp::attachInterface<ToLayoutOpVectorizationModel>(

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,6 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
573573
funcPassManager.addPass(createGPUCombineValueSemanticBarriersPass());
574574

575575
// Step 6. Lower special ops and vectorize.
576-
funcPassManager.addPass(
577-
IREE::LinalgExt::createVectorizeIREELinalgExtOpsPass());
578576
funcPassManager.addPass(IREE::GPU::createVectorizeIREEGPUOpsPass());
579577
addGPUVectorizationPasses(funcPassManager, /*vectorizeCopies=*/false,
580578
/*enableMasking=*/true,
@@ -836,8 +834,6 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
836834
funcPassManager.addPass(tensor::createFoldTensorSubsetOpsPass());
837835

838836
// Linalg -> Vector
839-
funcPassManager.addPass(
840-
IREE::LinalgExt::createVectorizeIREELinalgExtOpsPass());
841837
addGPUVectorizationPasses(funcPassManager, /*vectorizeCopies=*/true,
842838
/*enableMasking=*/true);
843839

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ iree_compiler_cc_library(
4848
"TestReshapeFusion.cpp",
4949
"TileAttention.cpp",
5050
"TransposeFusion.cpp",
51-
"VectorizeIREELinalgExtOps.cpp",
5251
],
5352
hdrs = [
5453
"LoopMappingUtils.h",

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ iree_cc_library(
4646
"TestReshapeFusion.cpp"
4747
"TileAttention.cpp"
4848
"TransposeFusion.cpp"
49-
"VectorizeIREELinalgExtOps.cpp"
5049
DEPS
5150
::PassesIncGen
5251
LLVMSupport

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,4 @@ def TestReshapeFusionPass :
146146
let summary = "Test reshape fusion patterns";
147147
}
148148

149-
def VectorizeIREELinalgExtOpsPass :
150-
InterfacePass<"iree-linalg-ext-vectorize-ops", "mlir::FunctionOpInterface"> {
151-
let summary = "Convert linalg_ext ops into their vector form.";
152-
let dependentDialects = [
153-
"::mlir::vector::VectorDialect",
154-
"::mlir::arith::ArithDialect"
155-
];
156-
}
157-
158149
#endif // IREE_DIALECT_LINALGEXT_PASSES

0 commit comments

Comments
 (0)