Skip to content

Commit 0adaf2e

Browse files
[VectorExt] Add TransferScatterOp
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent fb7e890 commit 0adaf2e

4 files changed

Lines changed: 397 additions & 23 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,35 +62,35 @@ void TransferGatherOp::getEffects(
6262
}
6363
}
6464

65-
// Verifier.
66-
67-
LogicalResult TransferGatherOp::verify() {
68-
OperandRange indexVecs = getIndexVecs();
69-
TypedValue<VectorType> vector = getVector();
70-
Value mask = getMask();
71-
SmallVector<AffineMap> indexingMaps = getIndexingMapsArray();
65+
// Shared verifier for TransferGatherOp and TransferScatterOp.
7266

67+
static LogicalResult
68+
verifyTransferGatherScatterLikeOp(Operation *op, VectorType vectorType,
69+
OperandRange indexVecs, Value mask,
70+
ArrayRef<AffineMap> indexingMaps) {
7371
// Check that we have the correct number of indexing maps.
7472
int64_t expectedNumIndexingMaps =
75-
/*sourceIndexingMap=*/1 + /*indexVecIndexingMaps=*/indexVecs.size() +
73+
/*baseIndexingMap=*/1 + /*indexVecIndexingMaps=*/indexVecs.size() +
7674
/*maskIndexingMap=*/(mask ? 1 : 0);
7775
if (expectedNumIndexingMaps != static_cast<int64_t>(indexingMaps.size())) {
78-
return emitOpError("expected ")
76+
return op->emitOpError("expected ")
7977
<< expectedNumIndexingMaps
8078
<< " indexing maps, got: " << indexingMaps.size();
8179
}
8280

83-
int64_t vectorRank = vector.getType().getRank();
81+
int64_t vectorRank = vectorType.getRank();
8482
int64_t indexSyms = indexVecs.size();
8583
for (AffineMap map : indexingMaps) {
8684
if (map.getNumDims() != vectorRank) {
87-
return emitOpError("expected all indexing maps to have number of dims "
88-
"equal to vector rank. expected: ")
85+
return op->emitOpError(
86+
"expected all indexing maps to have number of dims "
87+
"equal to vector rank. expected: ")
8988
<< vectorRank << ", got: " << map.getNumDims() << " dims";
9089
}
9190
if (map.getNumSymbols() != indexSyms) {
92-
return emitOpError("expected all indexing maps to have number of symbols "
93-
"equal to number of index vecs. expected: ")
91+
return op->emitOpError(
92+
"expected all indexing maps to have number of symbols "
93+
"equal to number of index vecs. expected: ")
9494
<< indexSyms << ", got: " << map.getNumSymbols() << " syms";
9595
}
9696
for (AffineExpr expr : map.getResults()) {
@@ -99,18 +99,18 @@ LogicalResult TransferGatherOp::verify() {
9999
}
100100
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
101101
if (constExpr.getValue() != 0) {
102-
return emitOpError("expected constant 0 in indexing map, got: ")
102+
return op->emitOpError("expected constant 0 in indexing map, got: ")
103103
<< constExpr.getValue();
104104
}
105105
continue;
106106
}
107-
return emitOpError(
107+
return op->emitOpError(
108108
"expected indexing map results to only be a dim, symbol, or 0");
109109
}
110110
}
111111

112112
// Extra verification for index vecs.
113-
ArrayRef<int64_t> vectorShape = vector.getType().getShape();
113+
ArrayRef<int64_t> vectorShape = vectorType.getShape();
114114
ArrayRef<AffineMap> vectorIndexingMaps =
115115
ArrayRef(indexingMaps).slice(1, indexSyms);
116116
for (auto [i, map] : llvm::enumerate(vectorIndexingMaps)) {
@@ -119,23 +119,25 @@ LogicalResult TransferGatherOp::verify() {
119119
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
120120
expectedShape.push_back(vectorShape[dim.getPosition()]);
121121
} else {
122-
return emitOpError(
122+
return op->emitOpError(
123123
"expected vector indexing maps to not have any symbols");
124124
}
125125
}
126126
// Scalar index: map must have 0 results and type must be plain index.
127127
if (isa<IndexType>(indexVecs[i].getType())) {
128128
if (!expectedShape.empty()) {
129-
return emitOpError("expected empty indexing map for scalar index vec "
130-
"at position ")
129+
return op->emitOpError(
130+
"expected empty indexing map for scalar index vec "
131+
"at position ")
131132
<< i;
132133
}
133134
continue;
134135
}
135136
ArrayRef<int64_t> actualShape =
136137
cast<VectorType>(indexVecs[i].getType()).getShape();
137138
if (ArrayRef<int64_t>(expectedShape) != actualShape) {
138-
return emitOpError("Mismatched vector shape for index vec at position ")
139+
return op->emitOpError(
140+
"Mismatched vector shape for index vec at position ")
139141
<< i << ". Expected: [" << expectedShape << "]" << ", got: ["
140142
<< actualShape << "]";
141143
}
@@ -149,13 +151,13 @@ LogicalResult TransferGatherOp::verify() {
149151
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
150152
expectedShape.push_back(vectorShape[dim.getPosition()]);
151153
} else {
152-
return emitOpError(
154+
return op->emitOpError(
153155
"expected mask indexing map to not have any symbols");
154156
}
155157
}
156158
ArrayRef<int64_t> actualShape = cast<VectorType>(mask.getType()).getShape();
157159
if (ArrayRef<int64_t>(expectedShape) != actualShape) {
158-
return emitOpError("Mismatched mask shape")
160+
return op->emitOpError("Mismatched mask shape")
159161
<< ". Expected: [" << expectedShape << "]" << ", got: ["
160162
<< actualShape << "]";
161163
}
@@ -164,6 +166,12 @@ LogicalResult TransferGatherOp::verify() {
164166
return success();
165167
}
166168

169+
LogicalResult TransferGatherOp::verify() {
170+
return verifyTransferGatherScatterLikeOp(
171+
getOperation(), getVector().getType(), getIndexVecs(), getMask(),
172+
getIndexingMapsArray());
173+
}
174+
167175
// Fold and canonicalization helpers.
168176

169177
static int64_t getVectorRank(Type type) {
@@ -663,6 +671,53 @@ void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
663671
ctx);
664672
}
665673

674+
//===----------------------------------------------------------------------===//
675+
// TransferScatterOp
676+
//===----------------------------------------------------------------------===//
677+
678+
Speculation::Speculatability TransferScatterOp::getSpeculatability() {
679+
if (isa<RankedTensorType>(getBase().getType())) {
680+
return Speculation::Speculatable;
681+
}
682+
return Speculation::NotSpeculatable;
683+
}
684+
685+
void TransferScatterOp::getEffects(
686+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
687+
&effects) {
688+
if (isa<MemRefType>(getBase().getType())) {
689+
effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
690+
SideEffects::DefaultResource::get());
691+
effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
692+
SideEffects::DefaultResource::get());
693+
}
694+
}
695+
696+
LogicalResult TransferScatterOp::verify() {
697+
if (failed(verifyTransferGatherScatterLikeOp(getOperation(), getVectorType(),
698+
getIndexVecs(), getMask(),
699+
getIndexingMapsArray()))) {
700+
return failure();
701+
}
702+
703+
// Verify result type matches base type for tensor semantics.
704+
if (hasTensorSemantics()) {
705+
if (!getResult()) {
706+
return emitOpError("expected result for tensor operand");
707+
}
708+
if (getResult().getType() != getBase().getType()) {
709+
return emitOpError("result type must match base type");
710+
}
711+
} else {
712+
// Memref semantics: no result expected.
713+
if (getResult()) {
714+
return emitOpError("unexpected result for memref operand");
715+
}
716+
}
717+
718+
return success();
719+
}
720+
666721
//===----------------------------------------------------------------------===//
667722
// YieldOp
668723
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,127 @@ def IREEVectorExt_TransferGatherOp : IREEVectorExt_PureOp<"transfer_gather", [
224224
let hasVerifier = 1;
225225
}
226226

227+
def IREEVectorExt_TransferScatterOp : IREEVectorExt_PureOp<"transfer_scatter", [
228+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
229+
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
230+
AttrSizedOperandSegments,
231+
AllElementTypesMatch<["base", "vector"]>
232+
]> {
233+
let arguments = (ins AnyShaped:$base,
234+
AnyVectorOfAnyRank:$vector,
235+
Variadic<Index>:$offsets,
236+
Variadic<AnyTypeOf<[Index, VectorOfAnyRankOf<[Index]>]>>:$index_vecs,
237+
AffineMapArrayAttr:$indexing_maps,
238+
Optional<VectorOfAnyRankOf<[I1]>>:$mask);
239+
let results = (outs Optional<AnyShaped>:$result);
240+
241+
let summary = [{Scatters a supervector from an SSA vector value into a shaped destination.}];
242+
243+
let description = [{
244+
The `transfer_scatter` operation is the write counterpart of
245+
`transfer_gather`. It writes elements from a vector into a shaped
246+
destination (memref or tensor), where each destination dimension can be
247+
independently contiguous, scattered, or broadcast.
248+
249+
The scatter indices are expected to be unique. If multiple vector elements
250+
map to the same destination location, there may be data races.
251+
252+
For tensor operands, the operation returns the modified tensor. For memref
253+
operands, the operation has no result.
254+
255+
Example — scatter write: writing values into scattered rows of a 2D dest:
256+
257+
```
258+
// dest[indices[i], j] = vector[i, j]
259+
%result = iree_vector_ext.transfer_scatter %vector into %dest[%c0, %c0]
260+
[%indices : vector<16xindex>] {
261+
indexing_maps = [
262+
affine_map<(d0, d1)[s0] -> (s0, d1)>,
263+
affine_map<(d0, d1)[s0] -> (d0)>
264+
]
265+
} : vector<16x8xf16>, tensor<4096x8xf16> -> tensor<4096x8xf16>
266+
```
267+
268+
Semantically, for each position in the source vector:
269+
270+
```
271+
dest[offsets[0] + f0(d, s), offsets[1] + f1(d, s), ...] = vector[d0, d1, ...]
272+
```
273+
274+
where each `fi` is the i-th result of the dest indexing map evaluated at
275+
the vector position `d = (d0, d1, ...)` and scattered index values
276+
`s = (s0, s1, ...)`.
277+
278+
The `indexing_maps` attribute follows the same structure as
279+
`transfer_gather`. See `transfer_gather` documentation for details. The
280+
only difference is that Map 0 describes the destination indexing rather
281+
than the source indexing.
282+
}];
283+
284+
let extraClassDeclaration = [{
285+
SmallVector<AffineMap> getIndexingMapsArray() {
286+
return llvm::to_vector(getIndexingMaps().getAsValueRange<AffineMapAttr>());
287+
}
288+
289+
AffineMap getDestIndexingMap() {
290+
return getIndexingMapsArray().front();
291+
}
292+
293+
SmallVector<AffineMap> getIndexVecIndexingMaps() {
294+
auto maps = getIndexingMapsArray();
295+
return SmallVector<AffineMap>(
296+
maps.begin() + 1, maps.begin() + 1 + getIndexVecs().size());
297+
}
298+
299+
std::optional<AffineMap> getMaskIndexingMap() {
300+
if (!getMask()) return std::nullopt;
301+
return getIndexingMapsArray().back();
302+
}
303+
304+
VectorType getVectorType() {
305+
return getVector().getType();
306+
}
307+
308+
ShapedType getBaseType() {
309+
return cast<ShapedType>(getBase().getType());
310+
}
311+
312+
bool hasTensorSemantics() {
313+
return isa<RankedTensorType>(getBase().getType());
314+
}
315+
316+
/// Invert the dest indexing map to get a permutation map suitable for
317+
/// vector.transfer_write. The dest map is
318+
/// (vector_dims)[symbols] -> (dest_dims). This returns
319+
/// (dest_dims) -> (vector_dims), where every vector dim that is not a
320+
/// direct result of the dest map (scattered, broadcast) becomes constant 0.
321+
AffineMap getPermutationMap() {
322+
AffineMap destMap = getDestIndexingMap();
323+
MLIRContext *ctx = getContext();
324+
SmallVector<AffineExpr> exprs(destMap.getNumDims(),
325+
getAffineConstantExpr(0, ctx));
326+
for (auto [i, expr] : llvm::enumerate(destMap.getResults())) {
327+
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
328+
exprs[dimExpr.getPosition()] = getAffineDimExpr(i, ctx);
329+
}
330+
}
331+
return AffineMap::get(destMap.getNumResults(), /*symbolCount=*/0,
332+
exprs, ctx);
333+
}
334+
}];
335+
336+
let assemblyFormat = [{
337+
$vector `into` $base `[` $offsets `]`
338+
(`[` $index_vecs^ `:` type($index_vecs) `]`)?
339+
(`,` $mask^)?
340+
attr-dict `:` type($vector) `,` type($base)
341+
(`,` type($mask)^)?
342+
(`->` type($result)^)?
343+
}];
344+
345+
let hasVerifier = 1;
346+
}
347+
227348
//===----------------------------------------------------------------------===//
228349
// Terminator ops.
229350
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/invalid.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,59 @@ func.func @index_vec_shape_mismatch(%indices: vector<128x64xindex>,
8585

8686
// -----
8787

88+
func.func @scatter_wrong_num_indexing_maps(%indices: vector<128xindex>,
89+
%vector: vector<128xf16>,
90+
%dest: tensor<128xf16>)
91+
-> tensor<128xf16> {
92+
93+
%c0 = arith.constant 0 : index
94+
95+
// expected-error @+1 {{'iree_vector_ext.transfer_scatter' op expected 2 indexing maps, got: 1}}
96+
%out = iree_vector_ext.transfer_scatter %vector into %dest[%c0]
97+
[%indices : vector<128xindex>] {
98+
indexing_maps = [affine_map<(d0)[s0] -> (s0)>]
99+
} : vector<128xf16>, tensor<128xf16> -> tensor<128xf16>
100+
101+
return %out : tensor<128xf16>
102+
}
103+
104+
// -----
105+
106+
func.func @scatter_index_vec_shape_mismatch(%indices: vector<128x64xindex>,
107+
%vector: vector<128x64xf16>,
108+
%dest: tensor<128x64xf16>)
109+
-> tensor<128x64xf16> {
110+
111+
%c0 = arith.constant 0 : index
112+
113+
// expected-error @+1 {{'iree_vector_ext.transfer_scatter' op Mismatched vector shape for index vec at position 0. Expected: [64, 128], got: [128, 64]}}
114+
%out = iree_vector_ext.transfer_scatter %vector into %dest[%c0, %c0]
115+
[%indices : vector<128x64xindex>] {
116+
indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
117+
affine_map<(d0, d1)[s0] -> (d1, d0)>]
118+
} : vector<128x64xf16>, tensor<128x64xf16> -> tensor<128x64xf16>
119+
120+
return %out : tensor<128x64xf16>
121+
}
122+
123+
// -----
124+
125+
func.func @scatter_memref_with_result(%vector: vector<128xf16>,
126+
%dest: memref<128xf16>)
127+
-> memref<128xf16> {
128+
129+
%c0 = arith.constant 0 : index
130+
131+
// expected-error @+1 {{'iree_vector_ext.transfer_scatter' op unexpected result for memref operand}}
132+
%out = iree_vector_ext.transfer_scatter %vector into %dest[%c0] {
133+
indexing_maps = [affine_map<(d0) -> (d0)>]
134+
} : vector<128xf16>, memref<128xf16> -> memref<128xf16>
135+
136+
return %out : memref<128xf16>
137+
}
138+
139+
// -----
140+
88141
func.func @arg_compare_dimension_out_of_bounds(%input: vector<4x128xf32>,
89142
%out_val: vector<4xf32>,
90143
%out_idx: vector<4xi32>)

0 commit comments

Comments
 (0)