@@ -62,35 +62,35 @@ void TransferGatherOp::getEffects(
6262 }
6363}
6464
65- // Verifier.
66-
67- LogicalResult TransferGatherOp::verify () {
68- OperandRange indexVecs = getIndexVecs ();
69- TypedValue<VectorType> vector = getVector ();
70- Value mask = getMask ();
71- SmallVector<AffineMap> indexingMaps = getIndexingMapsArray ();
65+ // Shared verifier for TransferGatherOp and TransferScatterOp.
7266
67+ static LogicalResult
68+ verifyTransferGatherScatterLikeOp (Operation *op, VectorType vectorType,
69+ OperandRange indexVecs, Value mask,
70+ ArrayRef<AffineMap> indexingMaps) {
7371 // Check that we have the correct number of indexing maps.
7472 int64_t expectedNumIndexingMaps =
75- /* sourceIndexingMap =*/ 1 + /* indexVecIndexingMaps=*/ indexVecs.size () +
73+ /* baseIndexingMap =*/ 1 + /* indexVecIndexingMaps=*/ indexVecs.size () +
7674 /* maskIndexingMap=*/ (mask ? 1 : 0 );
7775 if (expectedNumIndexingMaps != static_cast <int64_t >(indexingMaps.size ())) {
78- return emitOpError (" expected " )
76+ return op-> emitOpError (" expected " )
7977 << expectedNumIndexingMaps
8078 << " indexing maps, got: " << indexingMaps.size ();
8179 }
8280
83- int64_t vectorRank = vector. getType () .getRank ();
81+ int64_t vectorRank = vectorType .getRank ();
8482 int64_t indexSyms = indexVecs.size ();
8583 for (AffineMap map : indexingMaps) {
8684 if (map.getNumDims () != vectorRank) {
87- return emitOpError (" expected all indexing maps to have number of dims "
88- " equal to vector rank. expected: " )
85+ return op->emitOpError (
86+ " expected all indexing maps to have number of dims "
87+ " equal to vector rank. expected: " )
8988 << vectorRank << " , got: " << map.getNumDims () << " dims" ;
9089 }
9190 if (map.getNumSymbols () != indexSyms) {
92- return emitOpError (" expected all indexing maps to have number of symbols "
93- " equal to number of index vecs. expected: " )
91+ return op->emitOpError (
92+ " expected all indexing maps to have number of symbols "
93+ " equal to number of index vecs. expected: " )
9494 << indexSyms << " , got: " << map.getNumSymbols () << " syms" ;
9595 }
9696 for (AffineExpr expr : map.getResults ()) {
@@ -99,18 +99,18 @@ LogicalResult TransferGatherOp::verify() {
9999 }
100100 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
101101 if (constExpr.getValue () != 0 ) {
102- return emitOpError (" expected constant 0 in indexing map, got: " )
102+ return op-> emitOpError (" expected constant 0 in indexing map, got: " )
103103 << constExpr.getValue ();
104104 }
105105 continue ;
106106 }
107- return emitOpError (
107+ return op-> emitOpError (
108108 " expected indexing map results to only be a dim, symbol, or 0" );
109109 }
110110 }
111111
112112 // Extra verification for index vecs.
113- ArrayRef<int64_t > vectorShape = vector. getType () .getShape ();
113+ ArrayRef<int64_t > vectorShape = vectorType .getShape ();
114114 ArrayRef<AffineMap> vectorIndexingMaps =
115115 ArrayRef (indexingMaps).slice (1 , indexSyms);
116116 for (auto [i, map] : llvm::enumerate (vectorIndexingMaps)) {
@@ -119,23 +119,25 @@ LogicalResult TransferGatherOp::verify() {
119119 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
120120 expectedShape.push_back (vectorShape[dim.getPosition ()]);
121121 } else {
122- return emitOpError (
122+ return op-> emitOpError (
123123 " expected vector indexing maps to not have any symbols" );
124124 }
125125 }
126126 // Scalar index: map must have 0 results and type must be plain index.
127127 if (isa<IndexType>(indexVecs[i].getType ())) {
128128 if (!expectedShape.empty ()) {
129- return emitOpError (" expected empty indexing map for scalar index vec "
130- " at position " )
129+ return op->emitOpError (
130+ " expected empty indexing map for scalar index vec "
131+ " at position " )
131132 << i;
132133 }
133134 continue ;
134135 }
135136 ArrayRef<int64_t > actualShape =
136137 cast<VectorType>(indexVecs[i].getType ()).getShape ();
137138 if (ArrayRef<int64_t >(expectedShape) != actualShape) {
138- return emitOpError (" Mismatched vector shape for index vec at position " )
139+ return op->emitOpError (
140+ " Mismatched vector shape for index vec at position " )
139141 << i << " . Expected: [" << expectedShape << " ]" << " , got: ["
140142 << actualShape << " ]" ;
141143 }
@@ -149,13 +151,13 @@ LogicalResult TransferGatherOp::verify() {
149151 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
150152 expectedShape.push_back (vectorShape[dim.getPosition ()]);
151153 } else {
152- return emitOpError (
154+ return op-> emitOpError (
153155 " expected mask indexing map to not have any symbols" );
154156 }
155157 }
156158 ArrayRef<int64_t > actualShape = cast<VectorType>(mask.getType ()).getShape ();
157159 if (ArrayRef<int64_t >(expectedShape) != actualShape) {
158- return emitOpError (" Mismatched mask shape" )
160+ return op-> emitOpError (" Mismatched mask shape" )
159161 << " . Expected: [" << expectedShape << " ]" << " , got: ["
160162 << actualShape << " ]" ;
161163 }
@@ -164,6 +166,12 @@ LogicalResult TransferGatherOp::verify() {
164166 return success ();
165167}
166168
169+ LogicalResult TransferGatherOp::verify () {
170+ return verifyTransferGatherScatterLikeOp (
171+ getOperation (), getVector ().getType (), getIndexVecs (), getMask (),
172+ getIndexingMapsArray ());
173+ }
174+
167175// Fold and canonicalization helpers.
168176
169177static int64_t getVectorRank (Type type) {
@@ -663,6 +671,53 @@ void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
663671 ctx);
664672}
665673
674+ // ===----------------------------------------------------------------------===//
675+ // TransferScatterOp
676+ // ===----------------------------------------------------------------------===//
677+
678+ Speculation::Speculatability TransferScatterOp::getSpeculatability () {
679+ return Speculation::NotSpeculatable;
680+ }
681+
682+ void TransferScatterOp::getEffects (
683+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
684+ &effects) {
685+ if (isa<MemRefType>(getBase ().getType ())) {
686+ effects.emplace_back (MemoryEffects::Read::get (), &getBaseMutable (),
687+ SideEffects::DefaultResource::get ());
688+ effects.emplace_back (MemoryEffects::Write::get (), &getBaseMutable (),
689+ SideEffects::DefaultResource::get ());
690+ }
691+ }
692+
693+ LogicalResult TransferScatterOp::verify () {
694+ if (failed (verifyTransferGatherScatterLikeOp (getOperation (), getVectorType (),
695+ getIndexVecs (), getMask (),
696+ getIndexingMapsArray ()))) {
697+ return failure ();
698+ }
699+
700+ // Verify result type matches base type for tensor semantics.
701+ if (hasTensorSemantics ()) {
702+ if (!getResult ()) {
703+ return emitOpError (" expected result for tensor operand" );
704+ }
705+ if (getResult ().getType () != getBase ().getType ()) {
706+ return emitOpError (" result type must match base type" );
707+ }
708+ }
709+
710+ return success ();
711+ }
712+
713+ LogicalResult TransferScatterOp::fold (FoldAdaptor adaptor,
714+ SmallVectorImpl<OpFoldResult> &results) {
715+ return failure ();
716+ }
717+
718+ void TransferScatterOp::getCanonicalizationPatterns (RewritePatternSet &results,
719+ MLIRContext *ctx) {}
720+
666721// ===----------------------------------------------------------------------===//
667722// YieldOp
668723// ===----------------------------------------------------------------------===//
0 commit comments