Skip to content

Commit 8df2db9

Browse files
authored
[VectorDistribution] Relax layout size constraint (#23625)
Relax invariant on the nested layout attribute to allow the overall size of the layout to exceed the size of the underlying tensor. By allowing the layout to exceed the size of the tensor, we can select tile sizes friendly to hardware even if the tensor itself has an odd shape, i.e., a size not divisible by HW-friendly tile sizes. The additional elements will be masked out by code generation. This change also makes sure that the vectorization of `to_layout` operations inserts masks on `transfer_read/write`. It also modifies the tensor layout configuration pass to use ceil-division in these cases to ensure full coverage of the tensor. This is part of #23415. Assisted-by: Claude Code --------- Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
1 parent 04784a9 commit 8df2db9

6 files changed

Lines changed: 111 additions & 29 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ func.func @invalid_rank_nested_layout_anchor(%a: vector<16x16xf16>, %b: vector<1
617617
subgroup_tile = [1, 1],
618618
batch_tile = [2, 4],
619619
outer_tile = [1, 1],
620-
thread_tile = [8, 2],
620+
thread_tile = [2, 2],
621621
element_tile = [2, 2],
622622

623623
subgroup_strides = [0, 0],
@@ -628,7 +628,7 @@ func.func @invalid_rank_nested_layout_anchor(%a: vector<16x16xf16>, %b: vector<1
628628
func.func @invalid_size_nested_layout_anchor(%a: vector<16x16xf16>, %b: vector<16x16xf16>) -> vector<16x16xf16> {
629629
%c = arith.addf %a, %b : vector<16x16xf16>
630630
%cl = iree_vector_ext.to_layout %c to layout(#layout2) : vector<16x16xf16>
631-
// expected-error @above {{Vector shape: [16, 16] does not match the layout (nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 4], outer_tile = [1, 1], thread_tile = [8, 2], element_tile = [2, 2], subgroup_strides = [0, 0], thread_strides = [1, 8]>) at dim 0. Dimension expected by layout: 32 actual: 16}}
631+
// expected-error @above {{Vector shape: [16, 16] does not match the layout (nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 4], outer_tile = [1, 1], thread_tile = [2, 2], element_tile = [2, 2], subgroup_strides = [0, 0], thread_strides = [1, 8]>) at dim 0. Dimension expected by layout: 8 actual: 16}}
632632
func.return %cl : vector<16x16xf16>
633633
}
634634

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,11 @@ LogicalResult NestedLayoutAttr::isValidLayout(ShapedType shapeTy,
416416
<< shape.size() << ") does not match rank of layout (" << rank
417417
<< ").";
418418
}
419+
if (isa<RankedTensorType>(shapeTy)) {
420+
// We do not verify layout size for tensors, as we allow the layout size to
421+
// exceed the tensor size and handle that through padding/masking.
422+
return success();
423+
}
419424
// Multiply all shapes in the layout.
420425
for (int i = 0, e = rank; i < e; ++i) {
421426
int64_t expectedShape = getSubgroupTile()[i] * getBatchTile()[i] *

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,29 @@ struct VectorizeToLayoutOpPattern final
2929
using Base::Base;
3030

3131
vector::TransferReadOp
32-
createReadOp(PatternRewriter &rewriter,
32+
createReadOp(ImplicitLocOpBuilder &builder,
3333
IREE::VectorExt::ToLayoutOp toLayoutOp) const {
34-
Location loc = toLayoutOp.getLoc();
3534
ShapedType inputTy = toLayoutOp.getType();
36-
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
37-
auto identityMap = rewriter.getMultiDimIdentityMap(inputTy.getRank());
35+
auto zero = arith::ConstantIndexOp::create(builder, 0);
36+
auto identityMap = builder.getMultiDimIdentityMap(inputTy.getRank());
3837
SmallVector<int64_t> readShape =
3938
toLayoutOp.getLayout().getUndistributedShape();
4039
Value mask = nullptr;
41-
if (!toLayoutOp.getType().hasStaticShape()) {
42-
SmallVector<OpFoldResult> mixedSourceDims =
43-
tensor::getMixedSizes(rewriter, loc, toLayoutOp.getInput());
44-
auto maskType = VectorType::get(readShape, rewriter.getI1Type());
45-
mask = vector::CreateMaskOp::create(rewriter, loc, maskType,
46-
mixedSourceDims);
40+
bool needsMask = !toLayoutOp.getType().hasStaticShape() ||
41+
(readShape != inputTy.getShape());
42+
if (needsMask) {
43+
SmallVector<OpFoldResult> mixedSourceDims = tensor::getMixedSizes(
44+
builder, builder.getLoc(), toLayoutOp.getInput());
45+
auto maskType = VectorType::get(readShape, builder.getI1Type());
46+
mask = vector::CreateMaskOp::create(builder, maskType, mixedSourceDims);
4747
}
4848
VectorType vectorType =
4949
VectorType::get(readShape, inputTy.getElementType());
50-
auto inBounds = rewriter.getBoolArrayAttr(
51-
SmallVector<bool>(vectorType.getRank(), true));
52-
auto padValue =
53-
ub::PoisonOp::create(rewriter, loc, inputTy.getElementType());
50+
auto inBounds =
51+
builder.getBoolArrayAttr(SmallVector<bool>(vectorType.getRank(), true));
52+
auto padValue = ub::PoisonOp::create(builder, inputTy.getElementType());
5453
auto read = vector::TransferReadOp::create(
55-
rewriter, loc,
54+
builder,
5655
/*type=*/vectorType,
5756
/*source=*/toLayoutOp.getInput(),
5857
/*indices=*/ValueRange{SmallVector<Value>(readShape.size(), zero)},
@@ -64,19 +63,18 @@ struct VectorizeToLayoutOpPattern final
6463
}
6564

6665
vector::TransferWriteOp
67-
createWriteOp(PatternRewriter &rewriter,
66+
createWriteOp(ImplicitLocOpBuilder &builder,
6867
IREE::VectorExt::ToLayoutOp tensorLayoutOp,
6968
Value vectorLayoutOp, Value mask) const {
70-
Location loc = tensorLayoutOp.getLoc();
7169
ShapedType tensorTy = tensorLayoutOp.getType();
7270
auto resType =
7371
RankedTensorType::get(tensorTy.getShape(), tensorTy.getElementType());
74-
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
72+
auto zero = arith::ConstantIndexOp::create(builder, 0);
7573
int64_t rank = tensorTy.getShape().size();
76-
auto inBounds = rewriter.getBoolArrayAttr(SmallVector<bool>(rank, true));
77-
auto identityMap = rewriter.getMultiDimIdentityMap(tensorTy.getRank());
74+
auto inBounds = builder.getBoolArrayAttr(SmallVector<bool>(rank, true));
75+
auto identityMap = builder.getMultiDimIdentityMap(tensorTy.getRank());
7876
return vector::TransferWriteOp::create(
79-
rewriter, loc,
77+
builder,
8078
/*result=*/resType,
8179
/*vector=*/vectorLayoutOp,
8280
/*source=*/tensorLayoutOp.getInput(),
@@ -94,14 +92,15 @@ struct VectorizeToLayoutOpPattern final
9492
OpBuilder::InsertionGuard g(rewriter);
9593
rewriter.setInsertionPoint(toLayoutOp);
9694
Location loc = toLayoutOp.getLoc();
97-
vector::TransferReadOp readOp = createReadOp(rewriter, toLayoutOp);
95+
ImplicitLocOpBuilder builder{loc, rewriter};
96+
vector::TransferReadOp readOp = createReadOp(builder, toLayoutOp);
9897
// Create the toLayout operation but with vector types instead.
9998
auto newLayoutOp = IREE::VectorExt::ToLayoutOp::create(
100-
rewriter, loc, readOp, toLayoutOp.getLayout(),
99+
builder, readOp, toLayoutOp.getLayout(),
101100
toLayoutOp.getSharedMemoryConversion());
102101
// Create the write back to a tensor.
103102
vector::TransferWriteOp writeOp =
104-
createWriteOp(rewriter, toLayoutOp, newLayoutOp, readOp.getMask());
103+
createWriteOp(builder, toLayoutOp, newLayoutOp, readOp.getMask());
105104
rewriter.replaceOp(toLayoutOp, writeOp);
106105
return success();
107106
}

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,28 @@ func.func @vectorize_to_layout(%A: tensor<64x64xf32>) -> tensor<64x64xf32> {
241241
// CHECK: %[[A_READ:.+]] = vector.transfer_read %[[AT]]
242242
// CHECK: %[[A:.+]] = iree_vector_ext.to_layout %[[A_READ]]
243243
// CHECK: %[[A_WRITE:.+]] = vector.transfer_write %[[A]], %[[AT]]
244+
245+
// -----
246+
247+
#layout = #iree_vector_ext.nested_layout<
248+
subgroup_tile = [1, 1],
249+
batch_tile = [4, 2],
250+
outer_tile = [1, 1],
251+
thread_tile = [8, 4],
252+
element_tile = [8, 8],
253+
254+
subgroup_strides = [0, 0],
255+
thread_strides = [4, 1]
256+
>
257+
258+
func.func @vectorize_to_layout_with_mask(%A: tensor<256x63xf32>) -> tensor<256x63xf32> {
259+
%AL = iree_vector_ext.to_layout %A to layout(#layout) : tensor<256x63xf32>
260+
return %AL : tensor<256x63xf32>
261+
}
262+
263+
// CHECK-LABEL: func.func @vectorize_to_layout_with_mask
264+
// CHECK-SAME: %[[AT:.+]]: tensor<256x63xf32>
265+
// CHECK: %[[MASK:.+]] = vector.constant_mask [256, 63]
266+
// CHECK: %[[A_READ:.+]] = vector.transfer_read %[[AT]]{{.*}} %[[MASK]]
267+
// CHECK: %[[A:.+]] = iree_vector_ext.to_layout %[[A_READ]]
268+
// CHECK: %[[A_WRITE:.+]] = vector.transfer_write %[[A]], %[[AT]]{{.*}} %[[MASK]]

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static IREE::Codegen::InnerTileDescAttrInterface getIntrinsic(Operation *op) {
6161
return mmaIntrinsic;
6262
}
6363

64-
/// Given two arrays bounds and tile, compute bounds /= tile.
64+
/// Given two arrays bounds and tile, compute bounds = ceil(bounds / tile).
6565
///
6666
/// If "tile" contains 0, or is smaller than bounds, divide bounds by 1
6767
/// for those values.
@@ -71,7 +71,7 @@ static IREE::Codegen::InnerTileDescAttrInterface getIntrinsic(Operation *op) {
7171
FailureOr<SmallVector<int64_t>> divideTile(SmallVector<int64_t> &bounds,
7272
ArrayRef<int64_t> tile) {
7373
assert(bounds.size() >= tile.size() &&
74-
"cannot divide bounds with a larger tile size");
74+
"cannot divide bounds with a different rank");
7575

7676
SmallVector<int64_t> divisor(bounds.size(), 1);
7777
for (auto [div, size] : llvm::zip(divisor, tile)) {
@@ -82,7 +82,7 @@ FailureOr<SmallVector<int64_t>> divideTile(SmallVector<int64_t> &bounds,
8282
}
8383

8484
for (auto [bound, div] : llvm::zip_equal(bounds, divisor)) {
85-
bound /= div;
85+
bound = llvm::divideCeil(bound, div);
8686
}
8787

8888
return divisor;

compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,56 @@ func.func @dynamic_infer_sizes_lowering_config(%in : tensor<4x32x?x128xf16>) ->
367367
// CHECK: %[[EMPTYL:.+]] = iree_vector_ext.to_layout %[[EMPTY]] to layout(#[[LAYOUT]]) : tensor<1x1x?x128xf16>
368368
// CHECK: %[[COPY:.+]] = linalg.copy {{.*}} ins(%[[EXTRACTL]] : tensor<1x1x?x128xf16>) outs(%[[EMPTYL]] : tensor<1x1x?x128xf16>)
369369
// CHECK: iree_vector_ext.to_layout %[[COPY]] to layout(#[[LAYOUT]]) : tensor<1x1x?x128xf16>
370+
371+
// -----
372+
373+
// Verify that the batch tile for a dimension that requires ceil division
374+
// (63 / 8 = 8, not 7) is computed correctly.
375+
376+
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
377+
workgroup_size = [512, 1, 1]
378+
subgroup_size = 64>
379+
380+
#maps = [
381+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
382+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
383+
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
384+
]
385+
386+
#traits = {
387+
indexing_maps = #maps,
388+
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
389+
lowering_config = #iree_gpu.lowering_config<{
390+
lane_basis = [[1, 1, 1, 1, 64], [1, 0, 3, 4]],
391+
subgroup_basis = [[1, 1, 1, 1, 8], [0, 1, 2, 4]],
392+
thread = [0, 0, 8, 0]
393+
}>
394+
}
395+
396+
func.func @contraction_ceildiv_batch(%lhs: tensor<1x1x63xf16>,
397+
%rhs: tensor<1x512x63xf16>,
398+
%init: tensor<1x512x1xf32>)
399+
-> tensor<1x512x1xf32>
400+
attributes { translation_info = #translation } {
401+
%out = linalg.generic #traits
402+
ins(%lhs, %rhs: tensor<1x1x63xf16>, tensor<1x512x63xf16>)
403+
outs(%init: tensor<1x512x1xf32>) {
404+
^bb0(%in: f16, %in_1: f16, %out: f32):
405+
%ex = arith.extf %in : f16 to f32
406+
%ex_1 = arith.extf %in_1 : f16 to f32
407+
%mul = arith.mulf %ex, %ex_1 : f32
408+
%sum = arith.addf %mul, %out : f32
409+
linalg.yield %sum : f32
410+
} -> tensor<1x512x1xf32>
411+
return %out : tensor<1x512x1xf32>
412+
}
413+
414+
// CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<{{.*}}batch_tile = [1, 1, 8]{{.*}}element_tile = [1, 1, 8]{{.*}}>
415+
// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<{{.*}}batch_tile = [1, 1, 8]{{.*}}element_tile = [1, 1, 8]{{.*}}>
416+
417+
// CHECK-LABEL: func.func @contraction_ceildiv_batch
418+
419+
// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]])
420+
// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]])
421+
// CHECK: linalg.generic
422+
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]

0 commit comments

Comments
 (0)