Skip to content

Commit 2f6bc47

Browse files
authored
[mlir][vector] Standardise valueToStore Naming Across Vector Ops (NFC) (llvm#134206)
This change standardises the naming convention for the argument representing the value to store in various vector operations. Specifically, it ensures that all vector ops storing a value—whether into memory, a tensor, or another vector — use `valueToStore` for the corresponding argument name. Updated operations: * `vector.transfer_write`, `vector.insert`, `vector.scalable_insert`, `vector.insert_strided_slice`. For reference, here are operations that currently use `valueToStore`: * `vector.store` `vector.scatter`, `vector.compressstore`, `vector.maskedstore`. This change is non-functional (NFC) and does not affect the functionality of these operations. Implements llvm#131602
1 parent bafa2f4 commit 2f6bc47

File tree

16 files changed

+119
-86
lines changed

16 files changed

+119
-86
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

+22-21
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ def Vector_InsertOp :
907907
}];
908908

909909
let arguments = (ins
910-
AnyType:$source,
910+
AnyType:$valueToStore,
911911
AnyVectorOfAnyRank:$dest,
912912
Variadic<Index>:$dynamic_position,
913913
DenseI64ArrayAttr:$static_position
@@ -916,15 +916,15 @@ def Vector_InsertOp :
916916

917917
let builders = [
918918
// Builder to insert a scalar/rank-0 vector into a rank-0 vector.
919-
OpBuilder<(ins "Value":$source, "Value":$dest)>,
920-
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
921-
OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
922-
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
923-
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
919+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest)>,
920+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "int64_t":$position)>,
921+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "OpFoldResult":$position)>,
922+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<int64_t>":$position)>,
923+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
924924
];
925925

926926
let extraClassDeclaration = extraPoisonClassDeclaration # [{
927-
Type getSourceType() { return getSource().getType(); }
927+
Type getValueToStoreType() { return getValueToStore().getType(); }
928928
VectorType getDestVectorType() {
929929
return ::llvm::cast<VectorType>(getDest().getType());
930930
}
@@ -946,8 +946,8 @@ def Vector_InsertOp :
946946
}];
947947

948948
let assemblyFormat = [{
949-
$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
950-
attr-dict `:` type($source) `into` type($dest)
949+
$valueToStore `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
950+
attr-dict `:` type($valueToStore) `into` type($dest)
951951
}];
952952

953953
let hasCanonicalizer = 1;
@@ -957,13 +957,13 @@ def Vector_InsertOp :
957957

958958
def Vector_ScalableInsertOp :
959959
Vector_Op<"scalable.insert", [Pure,
960-
AllElementTypesMatch<["source", "dest"]>,
960+
AllElementTypesMatch<["valueToStore", "dest"]>,
961961
AllTypesMatch<["dest", "res"]>,
962962
PredOpTrait<"position is a multiple of the source length.",
963963
CPred<
964964
"(getPos() % getSourceVectorType().getNumElements()) == 0"
965965
>>]>,
966-
Arguments<(ins VectorOfRank<[1]>:$source,
966+
Arguments<(ins VectorOfRank<[1]>:$valueToStore,
967967
ScalableVectorOfRank<[1]>:$dest,
968968
I64Attr:$pos)>,
969969
Results<(outs ScalableVectorOfRank<[1]>:$res)> {
@@ -999,12 +999,12 @@ def Vector_ScalableInsertOp :
999999
}];
10001000

10011001
let assemblyFormat = [{
1002-
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
1002+
$valueToStore `,` $dest `[` $pos `]` attr-dict `:` type($valueToStore) `into` type($dest)
10031003
}];
10041004

10051005
let extraClassDeclaration = extraPoisonClassDeclaration # [{
10061006
VectorType getSourceVectorType() {
1007-
return ::llvm::cast<VectorType>(getSource().getType());
1007+
return ::llvm::cast<VectorType>(getValueToStore().getType());
10081008
}
10091009
VectorType getDestVectorType() {
10101010
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1068,20 +1068,20 @@ def Vector_InsertStridedSliceOp :
10681068
PredOpTrait<"operand #0 and result have same element type",
10691069
TCresVTEtIsSameAsOpBase<0, 0>>,
10701070
AllTypesMatch<["dest", "res"]>]>,
1071-
Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
1071+
Arguments<(ins AnyVectorOfNonZeroRank:$valueToStore, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
10721072
I64ArrayAttr:$strides)>,
10731073
Results<(outs AnyVectorOfNonZeroRank:$res)> {
10741074
let summary = "strided_slice operation";
10751075
let description = [{
1076-
Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
1076+
Takes a k-D valueToStore vector, an n-D destination vector (n >= k), n-sized
10771077
`offsets` integer array attribute, a k-sized `strides` integer array attribute
1078-
and inserts the k-D source vector as a strided subvector at the proper offset
1078+
and inserts the k-D valueToStore vector as a strided subvector at the proper offset
10791079
into the n-D destination vector.
10801080

10811081
At the moment strides must contain only 1s.
10821082

10831083
Returns an n-D vector that is a copy of the n-D destination vector in which
1084-
the last k-D dimensions contain the k-D source vector elements strided at
1084+
the last k-D dimensions contain the k-D valueToStore vector elements strided at
10851085
the proper location as specified by the offsets.
10861086

10871087
Example:
@@ -1094,16 +1094,17 @@ def Vector_InsertStridedSliceOp :
10941094
}];
10951095

10961096
let assemblyFormat = [{
1097-
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
1097+
$valueToStore `,` $dest attr-dict `:` type($valueToStore) `into` type($dest)
10981098
}];
10991099

11001100
let builders = [
1101-
OpBuilder<(ins "Value":$source, "Value":$dest,
1101+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest,
11021102
"ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
11031103
];
11041104
let extraClassDeclaration = [{
1105+
// TODO: Rename
11051106
VectorType getSourceVectorType() {
1106-
return ::llvm::cast<VectorType>(getSource().getType());
1107+
return ::llvm::cast<VectorType>(getValueToStore().getType());
11071108
}
11081109
VectorType getDestVectorType() {
11091110
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1520,7 +1521,7 @@ def Vector_TransferWriteOp :
15201521
AttrSizedOperandSegments,
15211522
DestinationStyleOpInterface
15221523
]>,
1523-
Arguments<(ins AnyVectorOfAnyRank:$vector,
1524+
Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
15241525
AnyShaped:$source,
15251526
Variadic<Index>:$indices,
15261527
AffineMapAttr:$permutation_map,

mlir/include/mlir/Interfaces/VectorInterfaces.td

+9-5
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
124124
/*methodName=*/"getVector",
125125
/*args=*/(ins)
126126
>,
127+
InterfaceMethod<
128+
/*desc=*/[{
129+
Return the type of the vector that this operation operates on.
130+
}],
131+
/*retTy=*/"::mlir::VectorType",
132+
/*methodName=*/"getVectorType",
133+
/*args=*/(ins)
134+
>,
127135
InterfaceMethod<
128136
/*desc=*/[{
129137
Return the indices that specify the starting offsets into the source
@@ -133,6 +141,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
133141
/*methodName=*/"getIndices",
134142
/*args=*/(ins)
135143
>,
144+
136145
InterfaceMethod<
137146
/*desc=*/[{
138147
Return the permutation map that describes the mapping of vector
@@ -202,11 +211,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
202211
return $_op.getPermutationMap().getNumResults();
203212
}
204213

205-
/// Return the type of the vector that this operation operates on.
206-
::mlir::VectorType getVectorType() {
207-
return ::llvm::cast<::mlir::VectorType>($_op.getVector().getType());
208-
}
209-
210214
/// Return "true" if at least one of the vector dimensions is a broadcasted
211215
/// dimension.
212216
bool hasBroadcastDim() {

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ struct VectorInsertToArmSMELowering
579579
auto loc = insertOp.getLoc();
580580
auto position = insertOp.getMixedPosition();
581581

582-
Value source = insertOp.getSource();
582+
Value source = insertOp.getValueToStore();
583583

584584
// Overwrite entire vector with value. Should be handled by folder, but
585585
// just to be safe.

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,7 @@ class VectorInsertOpConversion
12571257
// We are going to mutate this 1D vector until it is either the final
12581258
// result (in the non-aggregate case) or the value that needs to be
12591259
// inserted into the aggregate result.
1260-
Value sourceAggregate = adaptor.getSource();
1260+
Value sourceAggregate = adaptor.getValueToStore();
12611261
if (insertIntoInnermostDim) {
12621262
// Scalar-into-1D-vector case, so we know we will have to create a
12631263
// InsertElementOp. The question is into what destination.
@@ -1279,7 +1279,8 @@ class VectorInsertOpConversion
12791279
}
12801280
// Insert the scalar into the 1D vector.
12811281
sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
1282-
loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
1282+
loc, sourceAggregate.getType(), sourceAggregate,
1283+
adaptor.getValueToStore(),
12831284
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
12841285
}
12851286

@@ -1305,7 +1306,7 @@ struct VectorScalableInsertOpLowering
13051306
matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
13061307
ConversionPatternRewriter &rewriter) const override {
13071308
rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1308-
insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1309+
insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
13091310
return success();
13101311
}
13111312
};

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ struct PrepareTransferWriteConversion
661661
buffers.dataBuffer);
662662
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
663663
rewriter.modifyOpInPlace(xferOp, [&]() {
664-
xferOp.getVectorMutable().assign(loadedVec);
664+
xferOp.getValueToStoreMutable().assign(loadedVec);
665665
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
666666
});
667667

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -287,16 +287,16 @@ struct VectorInsertOpConvert final
287287
LogicalResult
288288
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
289289
ConversionPatternRewriter &rewriter) const override {
290-
if (isa<VectorType>(insertOp.getSourceType()))
290+
if (isa<VectorType>(insertOp.getValueToStoreType()))
291291
return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
292292
if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
293293
return rewriter.notifyMatchFailure(insertOp,
294294
"unsupported dest vector type");
295295

296296
// Special case for inserting scalar values into size-1 vectors.
297-
if (insertOp.getSourceType().isIntOrFloat() &&
297+
if (insertOp.getValueToStoreType().isIntOrFloat() &&
298298
insertOp.getDestVectorType().getNumElements() == 1) {
299-
rewriter.replaceOp(insertOp, adaptor.getSource());
299+
rewriter.replaceOp(insertOp, adaptor.getValueToStore());
300300
return success();
301301
}
302302

@@ -307,14 +307,15 @@ struct VectorInsertOpConvert final
307307
insertOp,
308308
"Static use of poison index handled elsewhere (folded to poison)");
309309
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
310-
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
310+
insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
311311
} else {
312312
Value sanitizedIndex = sanitizeDynamicIndex(
313313
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
314314
vector::InsertOp::kPoisonIndex,
315315
insertOp.getDestVectorType().getNumElements());
316316
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
317-
insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
317+
insertOp, insertOp.getDest(), adaptor.getValueToStore(),
318+
sanitizedIndex);
318319
}
319320
return success();
320321
}

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ struct LegalizeTransferWriteOpsByDecomposition
357357

358358
auto loc = writeOp.getLoc();
359359
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
360-
auto inputSMETiles = adaptor.getVector();
360+
auto inputSMETiles = adaptor.getValueToStore();
361361

362362
Value destTensorOrMemref = writeOp.getSource();
363363
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
@@ -464,7 +464,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
464464
rewriter.setInsertionPointToStart(storeLoop.getBody());
465465

466466
// For each sub-tile of the multi-tile `vectorType`.
467-
auto inputSMETiles = adaptor.getVector();
467+
auto inputSMETiles = adaptor.getValueToStore();
468468
auto tileSliceIndex = storeLoop.getInductionVar();
469469
for (auto [index, smeTile] : llvm::enumerate(
470470
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
371371
if (failed(maybeNewLoop))
372372
return WalkResult::interrupt();
373373

374-
transferWrite.getVectorMutable().assign(
374+
transferWrite.getValueToStoreMutable().assign(
375375
maybeNewLoop->getOperation()->getResults().back());
376376
changed = true;
377377
// Need to interrupt and restart because erasing the loop messes up

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -3177,8 +3177,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
31773177
rewriter.create<vector::TransferWriteOp>(
31783178
xferOp.getLoc(), vector, out, xferOp.getIndices(),
31793179
xferOp.getPermutationMapAttr(), xferOp.getMask(),
3180-
rewriter.getBoolArrayAttr(
3181-
SmallVector<bool>(vector.getType().getRank(), false)));
3180+
rewriter.getBoolArrayAttr(SmallVector<bool>(
3181+
dyn_cast<VectorType>(vector.getType()).getRank(), false)));
31823182

31833183
rewriter.eraseOp(copyOp);
31843184
rewriter.eraseOp(xferOp);

0 commit comments

Comments
 (0)