@@ -468,6 +468,7 @@ struct WgToSgVectorBroadcastOp
468468 LogicalResult
469469 matchAndRewrite (vector::BroadcastOp op, OneToNOpAdaptor adaptor,
470470 ConversionPatternRewriter &rewriter) const override {
471+
471472 VectorType resultType = op.getResult ().getType ();
472473 ArrayRef<int64_t > wgShape = resultType.getShape ();
473474
@@ -476,43 +477,24 @@ struct WgToSgVectorBroadcastOp
476477 if (!layout || !layout.isForWorkgroup ())
477478 return failure ();
478479
479- // TODO: Currently only supports cases where the source and result ranks
480- // are the same.
481- auto srcType =
482- dyn_cast<VectorType>(adaptor.getOperands ().front ()[0 ].getType ());
483- if (!srcType || srcType.getRank () != resultType.getRank ())
484- return failure ();
485-
486480 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
487481 VectorType newResultType =
488482 VectorType::get (sgShape, resultType.getElementType ());
489483
490- // Check if the output layout is distributable
491- SmallVector<int64_t > sgLayout = layout.getEffectiveSgLayoutAsInt ();
492- if (sgLayout.empty ())
493- return failure ();
494-
495484 if (!xegpu::XeGPUDialect::isEvenlyDistributable (wgShape, layout))
496485 return failure ();
497486
498- // Check if the srcShape has unit dim in dimensions being broadcasted,
499- // and the other dimensions are the same as the destination type
500- // TODO: Generalize it
501- auto srcShape = srcType.getShape ();
502- for (size_t i = 0 ; i < srcShape.size (); ++i) {
503- if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
504- return failure ();
505- }
506-
507487 SmallVector<Value> newBroadcastOps;
508488 for (auto operand : adaptor.getOperands ().front ()) {
509489 auto newBroadcast = vector::BroadcastOp::create (rewriter, op.getLoc (),
510490 newResultType, operand);
511- xegpu::setDistributeLayoutAttr (newBroadcast->getResult (0 ),
512- layout.dropSgLayoutAndData ());
491+ if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
492+ !layout.getEffectiveInstDataAsInt ().empty ())
493+ xegpu::setDistributeLayoutAttr (newBroadcast->getResult (0 ),
494+ layout.dropSgLayoutAndData ());
495+
513496 newBroadcastOps.push_back (newBroadcast.getResult ());
514497 }
515-
516498 rewriter.replaceOpWithMultiple (op, {newBroadcastOps});
517499 return success ();
518500 }
@@ -564,9 +546,11 @@ struct WgToSgElementwiseOp : public ConversionPattern {
564546 // Copy all attributes, but update "layout_result_0" to drop
565547 // sgLayout/sgData
566548 for (auto attr : op->getAttrs ()) {
567- if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue ())) {
568- if (auto newLayout = layout.dropSgLayoutAndData ())
569- state.addAttribute (attr.getName (), newLayout);
549+ if (auto layout =
550+ dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue ())) {
551+ if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
552+ !layout.getEffectiveInstDataAsInt ().empty ())
553+ state.addAttribute (attr.getName (), layout.dropSgLayoutAndData ());
570554 } else {
571555 state.addAttribute (attr.getName (), attr.getValue ());
572556 }
@@ -757,8 +741,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
757741 auto sgAttr = DenseElementsAttr::get (newType, singleVal);
758742 auto cstOp =
759743 arith::ConstantOp::create (rewriter, op.getLoc (), newType, sgAttr);
760- if (auto newLayout = layout.dropSgLayoutAndData ())
761- xegpu::setDistributeLayoutAttr (cstOp->getResult (0 ), newLayout);
744+ if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
745+ !layout.getEffectiveInstDataAsInt ().empty ())
746+ xegpu::setDistributeLayoutAttr (cstOp->getResult (0 ),
747+ layout.dropSgLayoutAndData ());
762748 SmallVector<Value> newConsts (count, cstOp);
763749
764750 rewriter.replaceOpWithMultiple (op, {newConsts});
@@ -919,6 +905,128 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
919905 }
920906};
921907
908+ // This pattern distributes the vector.step ops to work at subgroup level
909+ struct WgToSgVectorStepOp : public OpConversionPattern <vector::StepOp> {
910+ using OpConversionPattern<vector::StepOp>::OpConversionPattern;
911+ LogicalResult
912+ matchAndRewrite (vector::StepOp op, OneToNOpAdaptor adaptor,
913+ ConversionPatternRewriter &rewriter) const override {
914+ xegpu::DistributeLayoutAttr layout =
915+ xegpu::getDistributeLayoutAttr (op.getResult ());
916+ if (!layout || !layout.isForWorkgroup ())
917+ return failure ();
918+
919+ Location loc = op.getLoc ();
920+ VectorType type = op.getResult ().getType ();
921+ auto wgShape = type.getShape ();
922+ std::optional<SmallVector<int64_t >> sgShape =
923+ getSgShapeAndCount (wgShape, layout).first ;
924+ if (!sgShape)
925+ return failure ();
926+
927+ Value sgId =
928+ gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
929+ auto sgOffsets = layout.getOffsets (rewriter, loc, sgId, wgShape);
930+ if (failed (sgOffsets))
931+ return failure ();
932+
933+ VectorType newTy = type.cloneWith (*sgShape, type.getElementType ());
934+ auto steps = vector::StepOp::create (rewriter, loc, newTy);
935+ SmallVector<Value> newOps;
936+ for (auto offsets : *sgOffsets) {
937+ // Broadcast the offset scalar to a vector & add to the base steps
938+ auto bcastOffset =
939+ vector::BroadcastOp::create (rewriter, loc, newTy, offsets[0 ]);
940+ auto finalSteps =
941+ arith::AddIOp::create (rewriter, loc, steps, bcastOffset);
942+ if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
943+ !layout.getEffectiveInstDataAsInt ().empty ()) {
944+ xegpu::setDistributeLayoutAttr (steps->getResult (0 ),
945+ layout.dropSgLayoutAndData ());
946+ xegpu::setDistributeLayoutAttr (bcastOffset->getResult (0 ),
947+ layout.dropSgLayoutAndData ());
948+ xegpu::setDistributeLayoutAttr (finalSteps->getResult (0 ),
949+ layout.dropSgLayoutAndData ());
950+ }
951+ newOps.push_back (finalSteps);
952+ }
953+
954+ rewriter.replaceOpWithMultiple (op, {newOps});
955+ return success ();
956+ }
957+ };
958+
959+ // This pattern transforms vector.shape_cast ops to work at subgroup level.
960+ struct WgToSgVectorShapeCastOp
961+ : public OpConversionPattern<vector::ShapeCastOp> {
962+ using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
963+
964+ LogicalResult
965+ matchAndRewrite (vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
966+ ConversionPatternRewriter &rewriter) const override {
967+
968+ VectorType resultType = dyn_cast<VectorType>(op.getResult ().getType ());
969+ if (!resultType)
970+ return failure ();
971+
972+ ArrayRef<int64_t > wgShape = resultType.getShape ();
973+ xegpu::DistributeLayoutAttr layout =
974+ xegpu::getDistributeLayoutAttr (op.getResult ());
975+ if (!layout || !layout.isForWorkgroup ())
976+ return failure ();
977+
978+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
979+ VectorType newResultType =
980+ VectorType::get (sgShape, resultType.getElementType ());
981+
982+ // TODO: Add check for compatible layouts in layout attr.
983+ auto srcType = dyn_cast<VectorType>(adaptor.getSource ()[0 ].getType ());
984+ if (!srcType)
985+ return failure ();
986+
987+ // Check that shape_cast only adds/removes unit dimensions,
988+ auto onlyUnitDims = [](ArrayRef<int64_t > src, ArrayRef<int64_t > dst) {
989+ // Remove all 1s from both shapes and compare the rest.
990+ SmallVector<int64_t > srcNonUnit, dstNonUnit;
991+ for (int64_t d : src)
992+ if (d != 1 )
993+ srcNonUnit.push_back (d);
994+ for (int64_t d : dst)
995+ if (d != 1 )
996+ dstNonUnit.push_back (d);
997+ return srcNonUnit == dstNonUnit;
998+ };
999+
1000+ if (!onlyUnitDims (srcType.getShape (), sgShape))
1001+ return failure ();
1002+
1003+ // For rank reducing or increasing shape_cast ops, the lower rank layout
1004+ // must be a slice of higher rank layout.
1005+ int64_t sourceRank = srcType.getRank ();
1006+ int64_t resultRank = sgShape.size ();
1007+ xegpu::DistributeLayoutAttr sourceLayout =
1008+ xegpu::getDistributeLayoutAttr (op.getSource ());
1009+ if (sourceRank < resultRank && !sourceLayout.isSliceOf (layout))
1010+ return failure ();
1011+ if (sourceRank > resultRank && !layout.isSliceOf (sourceLayout))
1012+ return failure ();
1013+
1014+ SmallVector<Value> newShapeCastOps;
1015+ for (auto src : adaptor.getSource ()) {
1016+ auto newShapeCast =
1017+ rewriter.create <vector::ShapeCastOp>(op.getLoc (), newResultType, src);
1018+ if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
1019+ !layout.getEffectiveInstDataAsInt ().empty ())
1020+ xegpu::setDistributeLayoutAttr (newShapeCast->getResult (0 ),
1021+ layout.dropSgLayoutAndData ());
1022+ newShapeCastOps.push_back (newShapeCast.getResult ());
1023+ }
1024+
1025+ rewriter.replaceOpWithMultiple (op, {newShapeCastOps});
1026+ return success ();
1027+ }
1028+ };
1029+
9221030} // namespace
9231031
9241032namespace mlir {
@@ -932,7 +1040,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
9321040 WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
9331041 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
9341042 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
935- WgToSgStoreMatrixOp>(patterns.getContext ());
1043+ WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(
1044+ patterns.getContext ());
9361045}
9371046} // namespace xegpu
9381047} // namespace mlir
@@ -1054,7 +1163,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
10541163 auto vecType = dyn_cast<VectorType>(op.getType ());
10551164 if (!vecType)
10561165 return true ;
1057- return isLegal (xegpu::getDistributeLayoutAttr (op.getResult ()));
1166+
1167+ auto layout = xegpu::getDistributeLayoutAttr (op.getResult ());
1168+ return isLegal (layout);
1169+ });
1170+
1171+ target.addDynamicallyLegalOp <vector::ShapeCastOp, vector::StepOp>(
1172+ [=](Operation *op) -> bool {
1173+ // Check for either a SliceAttr or LayoutAttr on the result.
1174+ auto layout = xegpu::getDistributeLayoutAttr (op->getResult (0 ));
1175+ return isLegal (layout);
10581176 });
10591177
10601178 target.addDynamicallyLegalOp <xegpu::LoadGatherOp>(
0 commit comments