Skip to content

Commit 4b9697a

Browse files
[VectorExt] Add TransferScatterOp
Introduce `iree_vector_ext.transfer_scatter`, the write counterpart to `transfer_gather`. It uses the same unified `indexing_maps` attribute with per-dimension control (contiguous, scattered, broadcast). For tensor operands the op returns the modified tensor; for memref operands it has no result. A declarative assemblyFormat handles the optional result. Includes a shared verifier extracted from TransferGatherOp. Fold and canonicalization are stubs, to be populated in a follow-up. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fb7e890 commit 4b9697a

4 files changed

Lines changed: 331 additions & 23 deletions

File tree

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,35 +62,35 @@ void TransferGatherOp::getEffects(
6262
}
6363
}
6464

65-
// Verifier.
66-
67-
LogicalResult TransferGatherOp::verify() {
68-
OperandRange indexVecs = getIndexVecs();
69-
TypedValue<VectorType> vector = getVector();
70-
Value mask = getMask();
71-
SmallVector<AffineMap> indexingMaps = getIndexingMapsArray();
65+
// Shared verifier for TransferGatherOp and TransferScatterOp.
7266

67+
static LogicalResult
68+
verifyTransferGatherScatterLikeOp(Operation *op, VectorType vectorType,
69+
OperandRange indexVecs, Value mask,
70+
ArrayRef<AffineMap> indexingMaps) {
7371
// Check that we have the correct number of indexing maps.
7472
int64_t expectedNumIndexingMaps =
75-
/*sourceIndexingMap=*/1 + /*indexVecIndexingMaps=*/indexVecs.size() +
73+
/*baseIndexingMap=*/1 + /*indexVecIndexingMaps=*/indexVecs.size() +
7674
/*maskIndexingMap=*/(mask ? 1 : 0);
7775
if (expectedNumIndexingMaps != static_cast<int64_t>(indexingMaps.size())) {
78-
return emitOpError("expected ")
76+
return op->emitOpError("expected ")
7977
<< expectedNumIndexingMaps
8078
<< " indexing maps, got: " << indexingMaps.size();
8179
}
8280

83-
int64_t vectorRank = vector.getType().getRank();
81+
int64_t vectorRank = vectorType.getRank();
8482
int64_t indexSyms = indexVecs.size();
8583
for (AffineMap map : indexingMaps) {
8684
if (map.getNumDims() != vectorRank) {
87-
return emitOpError("expected all indexing maps to have number of dims "
88-
"equal to vector rank. expected: ")
85+
return op->emitOpError(
86+
"expected all indexing maps to have number of dims "
87+
"equal to vector rank. expected: ")
8988
<< vectorRank << ", got: " << map.getNumDims() << " dims";
9089
}
9190
if (map.getNumSymbols() != indexSyms) {
92-
return emitOpError("expected all indexing maps to have number of symbols "
93-
"equal to number of index vecs. expected: ")
91+
return op->emitOpError(
92+
"expected all indexing maps to have number of symbols "
93+
"equal to number of index vecs. expected: ")
9494
<< indexSyms << ", got: " << map.getNumSymbols() << " syms";
9595
}
9696
for (AffineExpr expr : map.getResults()) {
@@ -99,18 +99,18 @@ LogicalResult TransferGatherOp::verify() {
9999
}
100100
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
101101
if (constExpr.getValue() != 0) {
102-
return emitOpError("expected constant 0 in indexing map, got: ")
102+
return op->emitOpError("expected constant 0 in indexing map, got: ")
103103
<< constExpr.getValue();
104104
}
105105
continue;
106106
}
107-
return emitOpError(
107+
return op->emitOpError(
108108
"expected indexing map results to only be a dim, symbol, or 0");
109109
}
110110
}
111111

112112
// Extra verification for index vecs.
113-
ArrayRef<int64_t> vectorShape = vector.getType().getShape();
113+
ArrayRef<int64_t> vectorShape = vectorType.getShape();
114114
ArrayRef<AffineMap> vectorIndexingMaps =
115115
ArrayRef(indexingMaps).slice(1, indexSyms);
116116
for (auto [i, map] : llvm::enumerate(vectorIndexingMaps)) {
@@ -119,23 +119,25 @@ LogicalResult TransferGatherOp::verify() {
119119
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
120120
expectedShape.push_back(vectorShape[dim.getPosition()]);
121121
} else {
122-
return emitOpError(
122+
return op->emitOpError(
123123
"expected vector indexing maps to not have any symbols");
124124
}
125125
}
126126
// Scalar index: map must have 0 results and type must be plain index.
127127
if (isa<IndexType>(indexVecs[i].getType())) {
128128
if (!expectedShape.empty()) {
129-
return emitOpError("expected empty indexing map for scalar index vec "
130-
"at position ")
129+
return op->emitOpError(
130+
"expected empty indexing map for scalar index vec "
131+
"at position ")
131132
<< i;
132133
}
133134
continue;
134135
}
135136
ArrayRef<int64_t> actualShape =
136137
cast<VectorType>(indexVecs[i].getType()).getShape();
137138
if (ArrayRef<int64_t>(expectedShape) != actualShape) {
138-
return emitOpError("Mismatched vector shape for index vec at position ")
139+
return op->emitOpError(
140+
"Mismatched vector shape for index vec at position ")
139141
<< i << ". Expected: [" << expectedShape << "]" << ", got: ["
140142
<< actualShape << "]";
141143
}
@@ -149,13 +151,13 @@ LogicalResult TransferGatherOp::verify() {
149151
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
150152
expectedShape.push_back(vectorShape[dim.getPosition()]);
151153
} else {
152-
return emitOpError(
154+
return op->emitOpError(
153155
"expected mask indexing map to not have any symbols");
154156
}
155157
}
156158
ArrayRef<int64_t> actualShape = cast<VectorType>(mask.getType()).getShape();
157159
if (ArrayRef<int64_t>(expectedShape) != actualShape) {
158-
return emitOpError("Mismatched mask shape")
160+
return op->emitOpError("Mismatched mask shape")
159161
<< ". Expected: [" << expectedShape << "]" << ", got: ["
160162
<< actualShape << "]";
161163
}
@@ -164,6 +166,12 @@ LogicalResult TransferGatherOp::verify() {
164166
return success();
165167
}
166168

169+
/// Verifies the gather op by forwarding its vector type, index vectors,
/// optional mask, and indexing maps to the verifier shared with
/// TransferScatterOp.
LogicalResult TransferGatherOp::verify() {
  VectorType vectorType = getVector().getType();
  SmallVector<AffineMap> indexingMaps = getIndexingMapsArray();
  return verifyTransferGatherScatterLikeOp(getOperation(), vectorType,
                                           getIndexVecs(), getMask(),
                                           indexingMaps);
}
174+
167175
// Fold and canonicalization helpers.
168176

169177
static int64_t getVectorRank(Type type) {
@@ -663,6 +671,53 @@ void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
663671
ctx);
664672
}
665673

674+
//===----------------------------------------------------------------------===//
675+
// TransferScatterOp
676+
//===----------------------------------------------------------------------===//
677+
678+
/// ConditionallySpeculatable hook: unconditionally reports the scatter op as
/// not speculatable, regardless of its operands.
Speculation::Speculatability TransferScatterOp::getSpeculatability() {
679+
return Speculation::NotSpeculatable;
680+
}
681+
682+
/// MemoryEffectsOpInterface hook. Effects are only reported when the base is a
/// memref; for any other base type the list is left untouched. For a memref
/// base, both a read and a write effect on the base operand are registered
/// against the default resource.
void TransferScatterOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // Non-memref bases contribute no memory effects.
  if (!isa<MemRefType>(getBase().getType())) {
    return;
  }

  OpOperand *baseOperand = &getBaseMutable();
  auto *resource = SideEffects::DefaultResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), baseOperand, resource);
  effects.emplace_back(MemoryEffects::Write::get(), baseOperand, resource);
}
692+
693+
/// Verifies a transfer_scatter op.
///
/// First runs the indexing-map checks shared with TransferGatherOp. Then
/// enforces the result contract of the op:
///   - ranked tensor base: a result must be present and its type must equal
///     the base type;
///   - any other base (e.g. memref): the op must not define a result.
///
/// Returns failure (after emitting an op error) on the first violated rule.
LogicalResult TransferScatterOp::verify() {
  if (failed(verifyTransferGatherScatterLikeOp(getOperation(), getVectorType(),
                                               getIndexVecs(), getMask(),
                                               getIndexingMapsArray()))) {
    return failure();
  }

  if (hasTensorSemantics()) {
    // Tensor semantics: the updated tensor is returned as the result.
    if (!getResult()) {
      return emitOpError("expected result for tensor operand");
    }
    if (getResult().getType() != getBase().getType()) {
      return emitOpError("result type must match base type");
    }
    return success();
  }

  // Without tensor semantics the write happens in place, so a result value
  // would have no defined meaning. The optional `->` clause in the assembly
  // format would otherwise let one slip through unchecked.
  if (getResult()) {
    return emitOpError("expected no result unless base is a ranked tensor");
  }

  return success();
}
712+
713+
/// Folder stub: always returns failure, i.e. the op is never folded and
/// `results` is left untouched. The commit message states that folding is to
/// be populated in a follow-up change.
LogicalResult TransferScatterOp::fold(FoldAdaptor adaptor,
714+
SmallVectorImpl<OpFoldResult> &results) {
715+
return failure();
716+
}
717+
718+
/// Canonicalization stub: intentionally registers no patterns, so `results`
/// is left unchanged. The commit message states that patterns are to be added
/// in a follow-up change.
void TransferScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
719+
MLIRContext *ctx) {}
720+
666721
//===----------------------------------------------------------------------===//
667722
// YieldOp
668723
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,129 @@ def IREEVectorExt_TransferGatherOp : IREEVectorExt_PureOp<"transfer_gather", [
224224
let hasVerifier = 1;
225225
}
226226

227+
// Write counterpart of `transfer_gather`: scatters a vector into a shaped
// destination using the unified `indexing_maps` attribute. The result is
// optional and only spelled (after `->`) when the op returns a value.
def IREEVectorExt_TransferScatterOp : IREEVectorExt_PureOp<"transfer_scatter", [
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    DeclareOpInterfaceMethods<ConditionallySpeculatable>,
    AttrSizedOperandSegments,
    AllElementTypesMatch<["base", "vector"]>
  ]> {
  let arguments = (ins AnyShaped:$base,
                       AnyVectorOfAnyRank:$vector,
                       Variadic<Index>:$offsets,
                       Variadic<AnyTypeOf<[Index, VectorOfAnyRankOf<[Index]>]>>:$index_vecs,
                       AffineMapArrayAttr:$indexing_maps,
                       Optional<VectorOfAnyRankOf<[I1]>>:$mask);
  let results = (outs Optional<AnyShaped>:$result);

  let summary = [{Scatters a supervector from an SSA vector value into a shaped destination.}];

  let description = [{
    The `transfer_scatter` operation writes elements from a vector into a shaped
    destination (memref or tensor), where each destination dimension can be
    independently contiguous, scattered, or broadcast. The destination is the
    `base` operand, referred to as `dest` below.

    Semantically, for each position in the source vector:

    ```
    dest[offsets[0] + f0(d, s), offsets[1] + f1(d, s), ...] = vector[d0, d1, ...]
    ```

    where each `fi` is the i-th result of the dest indexing map evaluated at
    the vector position `d = (d0, d1, ...)` and scattered index values
    `s = (s0, s1, ...)`.

    The `indexing_maps` attribute describes all indexing. Every map has
    `numDims = source vector rank` and `numSymbols = number of index vecs`:

    - Map 0 (dest map): `(vector_dims)[symbols] -> (dest_dims)`.
      A dim expr means the dest dimension is contiguous (iterated in
      lockstep with the vector dimension). A symbol expr means the dest
      dimension is scattered (written via the corresponding index vector).
      A constant 0 means the dest dimension is broadcast (always writes
      at the base offset — multiple vector elements write to the same
      destination location).
    - Maps 1..N (index vec maps): `(vector_dims)[symbols] -> (index_vec_dims)`.
      Describes how each index vector is indexed from the vector iteration
      space. Only dim exprs are allowed.
    - Optional last map (mask map): present only when a mask operand is
      provided. Only dim exprs are allowed.

    For tensor operands, the operation returns the modified tensor, and the
    result type must be spelled after a trailing `->`. For memref operands,
    the operation has no result and the `->` clause is omitted.

    Example — scatter write: writing values into scattered rows of a 2D dest:

    ```
    // dest[indices[i], j] = vector[i, j]
    %result = iree_vector_ext.transfer_scatter %vector into %dest[%c0, %c0]
      [%indices : vector<16xindex>] {
        indexing_maps = [
          affine_map<(d0, d1)[s0] -> (s0, d1)>,
          affine_map<(d0, d1)[s0] -> (d0)>
        ]
      } : vector<16x8xf16>, tensor<4096x8xf16> -> tensor<4096x8xf16>
    ```

    The same write into a memref destination has no result:

    ```
    iree_vector_ext.transfer_scatter %vector into %dest[%c0, %c0]
      [%indices : vector<16xindex>] {
        indexing_maps = [
          affine_map<(d0, d1)[s0] -> (s0, d1)>,
          affine_map<(d0, d1)[s0] -> (d0)>
        ]
      } : vector<16x8xf16>, memref<4096x8xf16>
    ```
  }];

  let extraClassDeclaration = [{
    /// Returns all entries of `indexing_maps` as a vector of AffineMaps.
    SmallVector<AffineMap> getIndexingMapsArray() {
      return llvm::to_vector(getIndexingMaps().getAsValueRange<AffineMapAttr>());
    }

    /// Returns the first indexing map, which describes the destination.
    AffineMap getDestIndexingMap() {
      return getIndexingMapsArray().front();
    }

    /// Returns the indexing maps that follow the dest map, one per index
    /// vector operand.
    SmallVector<AffineMap> getIndexVecIndexingMaps() {
      auto maps = getIndexingMapsArray();
      return SmallVector<AffineMap>(
          maps.begin() + 1, maps.begin() + 1 + getIndexVecs().size());
    }

    /// Returns the last indexing map when a mask operand is present, and
    /// std::nullopt otherwise.
    std::optional<AffineMap> getMaskIndexingMap() {
      if (!getMask()) return std::nullopt;
      return getIndexingMapsArray().back();
    }

    /// Returns the type of the vector being written.
    VectorType getVectorType() {
      return getVector().getType();
    }

    /// Returns the type of the base operand as a ShapedType.
    ShapedType getBaseType() {
      return cast<ShapedType>(getBase().getType());
    }

    /// Returns true when the base operand is a ranked tensor.
    bool hasTensorSemantics() {
      return isa<RankedTensorType>(getBase().getType());
    }

    /// Invert the dest indexing map to get a permutation map suitable for
    /// vector.transfer_write. The dest map is
    /// (vector_dims)[symbols] -> (dest_dims). This returns
    /// (dest_dims) -> (vector_dims), where non-dim results (scattered
    /// symbols, broadcast constants) become constant 0.
    AffineMap getPermutationMap() {
      AffineMap destMap = getDestIndexingMap();
      MLIRContext *ctx = getContext();
      // Start with constant 0 for every vector dim; dims that appear as a
      // dest-map result are overwritten below with the matching dest dim.
      SmallVector<AffineExpr> exprs(destMap.getNumDims(),
                                    getAffineConstantExpr(0, ctx));
      for (auto [i, expr] : llvm::enumerate(destMap.getResults())) {
        if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
          exprs[dimExpr.getPosition()] = getAffineDimExpr(i, ctx);
        }
      }
      return AffineMap::get(destMap.getNumResults(), /*symbolCount=*/0,
                            exprs, ctx);
    }
  }];

  let assemblyFormat = "$vector `into` $base `[` $offsets `]` (`[` $index_vecs^ `:` type($index_vecs) `]`)? (`,` $mask^)? attr-dict `:` type($vector) `,` type($base) (`,` type($mask)^)? (`->` type($result)^)?";

  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}
349+
227350
//===----------------------------------------------------------------------===//
228351
// Terminator ops.
229352
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/invalid.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,43 @@ func.func @index_vec_shape_mismatch(%indices: vector<128x64xindex>,
8585

8686
// -----
8787

88+
// Negative test: transfer_scatter with one index vec and no mask needs two
// indexing maps (dest map + one index vec map), but only one is supplied.
func.func @scatter_wrong_num_indexing_maps(%indices: vector<128xindex>,
89+
%vector: vector<128xf16>,
90+
%dest: tensor<128xf16>)
91+
-> tensor<128xf16> {
92+
93+
%c0 = arith.constant 0 : index
94+
95+
// expected-error @+1 {{'iree_vector_ext.transfer_scatter' op expected 2 indexing maps, got: 1}}
96+
%out = iree_vector_ext.transfer_scatter %vector into %dest[%c0]
97+
[%indices : vector<128xindex>] {
98+
indexing_maps = [affine_map<(d0)[s0] -> (s0)>]
99+
} : vector<128xf16>, tensor<128xf16> -> tensor<128xf16>
100+
101+
return %out : tensor<128xf16>
102+
}
103+
104+
// -----
105+
106+
// Negative test: the index vec map (d1, d0) over a 128x64 vector implies an
// index vec of shape [64, 128], but the operand is vector<128x64xindex>.
func.func @scatter_index_vec_shape_mismatch(%indices: vector<128x64xindex>,
107+
%vector: vector<128x64xf16>,
108+
%dest: tensor<128x64xf16>)
109+
-> tensor<128x64xf16> {
110+
111+
%c0 = arith.constant 0 : index
112+
113+
// expected-error @+1 {{'iree_vector_ext.transfer_scatter' op Mismatched vector shape for index vec at position 0. Expected: [64, 128], got: [128, 64]}}
114+
%out = iree_vector_ext.transfer_scatter %vector into %dest[%c0, %c0]
115+
[%indices : vector<128x64xindex>] {
116+
indexing_maps = [affine_map<(d0, d1)[s0] -> (d0, s0)>,
117+
affine_map<(d0, d1)[s0] -> (d1, d0)>]
118+
} : vector<128x64xf16>, tensor<128x64xf16> -> tensor<128x64xf16>
119+
120+
return %out : tensor<128x64xf16>
121+
}
122+
123+
// -----
124+
88125
func.func @arg_compare_dimension_out_of_bounds(%input: vector<4x128xf32>,
89126
%out_val: vector<4xf32>,
90127
%out_idx: vector<4xi32>)

0 commit comments

Comments
 (0)