@@ -885,6 +885,69 @@ bool GatherLoweringHelper::isWarpLocal() {
885885 idxLayout.sublayout (kLane , otherDims);
886886}
887887
888+ ScatterLoweringHelper::ScatterLoweringHelper (triton::ScatterOp scatterOp)
889+ : scatterOp(scatterOp) {}
890+
891+ unsigned ScatterLoweringHelper::getScratchSizeInBytes () {
892+ // Otherwise, scattering will write into a temporary output tensor in shared
893+ // memory before materializing the final output registers.
894+ RankedTensorType dstType = scatterOp.getDst ().getType ();
895+ unsigned elemBytes = ceil<unsigned >(dstType.getElementTypeBitWidth (), 8 );
896+ unsigned dstBytes = product (dstType.getShape ()) * elemBytes;
897+ bool hasCombine =
898+ !scatterOp.getCombineOp ().empty () || scatterOp.getReduceKindAttr ();
899+ if (hasCombine) {
900+ // Extra i32 per element for CAS-based locks/flags.
901+ dstBytes += product (dstType.getShape ()) * sizeof (int32_t );
902+ }
903+ return dstBytes;
904+ }
905+
906+ bool ScatterLoweringHelper::isWarpLocal () {
907+ // The scatter is warp-local if all source/index writes for any destination
908+ // column can be serviced within a single warp.
909+ RankedTensorType dstType = scatterOp.getDst ().getType ();
910+ RankedTensorType srcType = scatterOp.getSrc ().getType ();
911+ RankedTensorType idxType = scatterOp.getIndices ().getType ();
912+ LinearLayout dstLayout = toLinearLayout (dstType);
913+ LinearLayout srcLayout = toLinearLayout (srcType);
914+ LinearLayout idxLayout = toLinearLayout (idxType);
915+
916+ Builder b (scatterOp.getContext ());
917+ StringAttr kBlock = b.getStringAttr (" block" );
918+ StringAttr kWarp = b.getStringAttr (" warp" );
919+ StringAttr kLane = b.getStringAttr (" lane" );
920+ StringAttr kScatterDim =
921+ b.getStringAttr (" dim" + std::to_string (scatterOp.getAxis ()));
922+
923+ // The scatter dimension must be invariant with respect to warp/block in all
924+ // participating tensors.
925+ if (!dstLayout.sublayoutIsZero ({kBlock , kWarp }, kScatterDim ) ||
926+ !srcLayout.sublayoutIsZero ({kBlock , kWarp }, kScatterDim ) ||
927+ !idxLayout.sublayoutIsZero ({kBlock , kWarp }, kScatterDim ))
928+ return false ;
929+
930+ SmallVector<StringAttr> otherDims;
931+ for (unsigned dim = 0 , rank = dstType.getRank (); dim < rank; ++dim) {
932+ if (dim != scatterOp.getAxis ()) {
933+ otherDims.push_back (b.getStringAttr (" dim" + Twine (dim)));
934+ }
935+ }
936+
937+ // Source/index and destination columns must line up identically across warps.
938+ if (dstLayout.sublayout ({kBlock , kWarp }, otherDims) !=
939+ srcLayout.sublayout ({kBlock , kWarp }, otherDims) ||
940+ dstLayout.sublayout ({kBlock , kWarp }, otherDims) !=
941+ idxLayout.sublayout ({kBlock , kWarp }, otherDims))
942+ return false ;
943+
944+ // Require lane ownership of columns to match for simpler codegen.
945+ return dstLayout.sublayout (kLane , otherDims) ==
946+ srcLayout.sublayout (kLane , otherDims) &&
947+ dstLayout.sublayout (kLane , otherDims) ==
948+ idxLayout.sublayout (kLane , otherDims);
949+ }
950+
888951unsigned getNumScratchElements (ArrayRef<unsigned > shape) {
889952 if (shape.empty ())
890953 return 0 ;
0 commit comments