Skip to content

Commit 8e17f80

Browse files
authored
[MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg (llvm#155443)
This PR adds patterns to distribute the vector.step and vector.shape_cast ops from workgroup (wg) to subgroup (sg) level; it also enables the constant, broadcast, and elementwise patterns to handle the slice attribute.
1 parent 32620c5 commit 8e17f80

File tree

2 files changed

+208
-31
lines changed

2 files changed

+208
-31
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 149 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ struct WgToSgVectorBroadcastOp
468468
LogicalResult
469469
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
470470
ConversionPatternRewriter &rewriter) const override {
471+
471472
VectorType resultType = op.getResult().getType();
472473
ArrayRef<int64_t> wgShape = resultType.getShape();
473474

@@ -476,43 +477,24 @@ struct WgToSgVectorBroadcastOp
476477
if (!layout || !layout.isForWorkgroup())
477478
return failure();
478479

479-
// TODO: Currently only supports cases where the source and result ranks
480-
// are the same.
481-
auto srcType =
482-
dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
483-
if (!srcType || srcType.getRank() != resultType.getRank())
484-
return failure();
485-
486480
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
487481
VectorType newResultType =
488482
VectorType::get(sgShape, resultType.getElementType());
489483

490-
// Check if the output layout is distributable
491-
SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
492-
if (sgLayout.empty())
493-
return failure();
494-
495484
if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
496485
return failure();
497486

498-
// Check if the srcShape has unit dim in dimensions being broadcasted,
499-
// and the other dimensions are the same as the destination type
500-
// TODO: Generalize it
501-
auto srcShape = srcType.getShape();
502-
for (size_t i = 0; i < srcShape.size(); ++i) {
503-
if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
504-
return failure();
505-
}
506-
507487
SmallVector<Value> newBroadcastOps;
508488
for (auto operand : adaptor.getOperands().front()) {
509489
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
510490
newResultType, operand);
511-
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
512-
layout.dropSgLayoutAndData());
491+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
492+
!layout.getEffectiveInstDataAsInt().empty())
493+
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
494+
layout.dropSgLayoutAndData());
495+
513496
newBroadcastOps.push_back(newBroadcast.getResult());
514497
}
515-
516498
rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
517499
return success();
518500
}
@@ -564,9 +546,11 @@ struct WgToSgElementwiseOp : public ConversionPattern {
564546
// Copy all attributes, but update "layout_result_0" to drop
565547
// sgLayout/sgData
566548
for (auto attr : op->getAttrs()) {
567-
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
568-
if (auto newLayout = layout.dropSgLayoutAndData())
569-
state.addAttribute(attr.getName(), newLayout);
549+
if (auto layout =
550+
dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
551+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
552+
!layout.getEffectiveInstDataAsInt().empty())
553+
state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
570554
} else {
571555
state.addAttribute(attr.getName(), attr.getValue());
572556
}
@@ -757,8 +741,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
757741
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
758742
auto cstOp =
759743
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
760-
if (auto newLayout = layout.dropSgLayoutAndData())
761-
xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
744+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
745+
!layout.getEffectiveInstDataAsInt().empty())
746+
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
747+
layout.dropSgLayoutAndData());
762748
SmallVector<Value> newConsts(count, cstOp);
763749

764750
rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -919,6 +905,128 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
919905
}
920906
};
921907

908+
// This pattern distributes the vector.step ops to work at subgroup level
909+
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
910+
using OpConversionPattern<vector::StepOp>::OpConversionPattern;
911+
LogicalResult
912+
matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
913+
ConversionPatternRewriter &rewriter) const override {
914+
xegpu::DistributeLayoutAttr layout =
915+
xegpu::getDistributeLayoutAttr(op.getResult());
916+
if (!layout || !layout.isForWorkgroup())
917+
return failure();
918+
919+
Location loc = op.getLoc();
920+
VectorType type = op.getResult().getType();
921+
auto wgShape = type.getShape();
922+
std::optional<SmallVector<int64_t>> sgShape =
923+
getSgShapeAndCount(wgShape, layout).first;
924+
if (!sgShape)
925+
return failure();
926+
927+
Value sgId =
928+
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
929+
auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
930+
if (failed(sgOffsets))
931+
return failure();
932+
933+
VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
934+
auto steps = vector::StepOp::create(rewriter, loc, newTy);
935+
SmallVector<Value> newOps;
936+
for (auto offsets : *sgOffsets) {
937+
// Broadcast the offset scalar to a vector & add to the base steps
938+
auto bcastOffset =
939+
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
940+
auto finalSteps =
941+
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
942+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
943+
!layout.getEffectiveInstDataAsInt().empty()) {
944+
xegpu::setDistributeLayoutAttr(steps->getResult(0),
945+
layout.dropSgLayoutAndData());
946+
xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
947+
layout.dropSgLayoutAndData());
948+
xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
949+
layout.dropSgLayoutAndData());
950+
}
951+
newOps.push_back(finalSteps);
952+
}
953+
954+
rewriter.replaceOpWithMultiple(op, {newOps});
955+
return success();
956+
}
957+
};
958+
959+
// This pattern transforms vector.shape_cast ops to work at subgroup level.
960+
struct WgToSgVectorShapeCastOp
961+
: public OpConversionPattern<vector::ShapeCastOp> {
962+
using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
963+
964+
LogicalResult
965+
matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
966+
ConversionPatternRewriter &rewriter) const override {
967+
968+
VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
969+
if (!resultType)
970+
return failure();
971+
972+
ArrayRef<int64_t> wgShape = resultType.getShape();
973+
xegpu::DistributeLayoutAttr layout =
974+
xegpu::getDistributeLayoutAttr(op.getResult());
975+
if (!layout || !layout.isForWorkgroup())
976+
return failure();
977+
978+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
979+
VectorType newResultType =
980+
VectorType::get(sgShape, resultType.getElementType());
981+
982+
// TODO: Add check for compatible layouts in layout attr.
983+
auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
984+
if (!srcType)
985+
return failure();
986+
987+
// Check that shape_cast only adds/removes unit dimensions,
988+
auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
989+
// Remove all 1s from both shapes and compare the rest.
990+
SmallVector<int64_t> srcNonUnit, dstNonUnit;
991+
for (int64_t d : src)
992+
if (d != 1)
993+
srcNonUnit.push_back(d);
994+
for (int64_t d : dst)
995+
if (d != 1)
996+
dstNonUnit.push_back(d);
997+
return srcNonUnit == dstNonUnit;
998+
};
999+
1000+
if (!onlyUnitDims(srcType.getShape(), sgShape))
1001+
return failure();
1002+
1003+
// For rank reducing or increasing shape_cast ops, the lower rank layout
1004+
// must be a slice of higher rank layout.
1005+
int64_t sourceRank = srcType.getRank();
1006+
int64_t resultRank = sgShape.size();
1007+
xegpu::DistributeLayoutAttr sourceLayout =
1008+
xegpu::getDistributeLayoutAttr(op.getSource());
1009+
if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
1010+
return failure();
1011+
if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
1012+
return failure();
1013+
1014+
SmallVector<Value> newShapeCastOps;
1015+
for (auto src : adaptor.getSource()) {
1016+
auto newShapeCast =
1017+
rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
1018+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1019+
!layout.getEffectiveInstDataAsInt().empty())
1020+
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
1021+
layout.dropSgLayoutAndData());
1022+
newShapeCastOps.push_back(newShapeCast.getResult());
1023+
}
1024+
1025+
rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
1026+
return success();
1027+
}
1028+
};
1029+
9221030
} // namespace
9231031

9241032
namespace mlir {
@@ -932,7 +1040,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
9321040
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
9331041
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
9341042
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
935-
WgToSgStoreMatrixOp>(patterns.getContext());
1043+
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(
1044+
patterns.getContext());
9361045
}
9371046
} // namespace xegpu
9381047
} // namespace mlir
@@ -1054,7 +1163,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
10541163
auto vecType = dyn_cast<VectorType>(op.getType());
10551164
if (!vecType)
10561165
return true;
1057-
return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
1166+
1167+
auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
1168+
return isLegal(layout);
1169+
});
1170+
1171+
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
1172+
[=](Operation *op) -> bool {
1173+
// Check for either a SliceAttr or LayoutAttr on the result.
1174+
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1175+
return isLegal(layout);
10581176
});
10591177

10601178
target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
44
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
5+
//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
56
gpu.module @test_distribution {
67
// CHECK-LABEL: create_nd_tdesc_no_offset
78
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -365,4 +366,62 @@ gpu.module @test_distribution {
365366
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
366367
gpu.return
367368
}
369+
370+
// CHECK-LABEL: vector_step_op
371+
gpu.func @vector_step_op_slice_attr() {
372+
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
373+
//CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
374+
//CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
375+
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
376+
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
377+
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
378+
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
379+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
380+
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
381+
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
382+
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
383+
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
384+
gpu.return
385+
}
386+
387+
gpu.func @vector_step_op_layout_attr() {
388+
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
389+
//CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
390+
//CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
391+
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
392+
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
393+
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
394+
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
395+
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
396+
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
397+
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
398+
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
399+
%step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
400+
gpu.return
401+
}
402+
403+
// CHECK-LABEL: constant_with_slice_attr
404+
gpu.func @constant_with_slice_attr() {
405+
//CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
406+
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
407+
gpu.return
408+
}
409+
410+
// CHECK-LABEL: vector_shape_cast
411+
gpu.func @vector_shape_cast() {
412+
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} dense<10> : vector<128xindex>
413+
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
414+
%muli = arith.muli %cst, %step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
415+
//CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex>
416+
%shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>} : vector<128xindex> to vector<1x1x1x128xindex>
417+
gpu.return
418+
}
419+
420+
// CHECK-LABEL: vector_broadcast
421+
gpu.func @vector_broadcast(%arg0: index, %arg1: index) {
422+
%muli = arith.muli %arg0, %arg1 : index
423+
// CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
424+
%broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
425+
gpu.return
426+
}
368427
}

0 commit comments

Comments
 (0)