@@ -264,8 +264,8 @@ getWarpsPerTile(DotOpInterface dotOp, const ArrayRef<int64_t> shape,
264264static bool bwdFilter (Operation *op) {
265265 return (op->hasTrait <OpTrait::Elementwise>() && isMemoryEffectFree (op)) ||
266266 isView (op) ||
267- isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp , BroadcastOp, ConvertLayoutOp>(
268- op);
267+ isa<Fp4ToFpOp, LoadOp, DescriptorLoadLikeOpInterface , BroadcastOp,
268+ ConvertLayoutOp>( op);
269269}
270270
271271// Finds the bitwidth with which the value x is loaded
@@ -284,7 +284,7 @@ static int computeOrigBitWidth(Value x) {
284284
285285 int origBitWidth = getElementTypeOrSelf (x).getIntOrFloatBitWidth ();
286286 for (auto op : slice) {
287- if (isa<LoadOp, DescriptorLoadOp >(op)) {
287+ if (isa<LoadOp, DescriptorLoadLikeOpInterface >(op)) {
288288 if (auto tensorTy =
289289 dyn_cast<RankedTensorType>(op->getResultTypes ().front ())) {
290290 origBitWidth =
@@ -473,8 +473,9 @@ static bool canUseTwoCTAs(triton::DotOp dotOp) {
473473 // Skip convert layouts.
474474 while (auto cvtOp = b.getDefiningOp <ConvertLayoutOp>())
475475 b = cvtOp.getSrc ();
476- return llvm::isa_and_nonnull<triton::LoadOp, triton::DescriptorLoadOp,
477- triton::DescriptorGatherOp>(b.getDefiningOp ());
476+ return llvm::isa_and_nonnull<triton::LoadOp,
477+ triton::DescriptorLoadLikeOpInterface>(
478+ b.getDefiningOp ());
478479}
479480
480481static DistributedEncodingTrait
@@ -501,8 +502,7 @@ static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) {
501502 while (auto cvtOp = b.getDefiningOp <ConvertLayoutOp>())
502503 b = cvtOp.getSrc ();
503504 auto loadOp = b.getDefiningOp ();
504- assert ((isa<triton::LoadOp, triton::DescriptorLoadOp,
505- triton::DescriptorGatherOp>(loadOp)) &&
505+ assert ((isa<triton::LoadOp, triton::DescriptorLoadLikeOpInterface>(loadOp)) &&
506506 " expected LoadOp" );
507507 RankedTensorType bType = cast<RankedTensorType>(b.getType ());
508508 auto currentLayout = cast<DistributedEncodingTrait>(bType.getEncoding ());
@@ -627,7 +627,7 @@ Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) {
627627 if (!op)
628628 return scale;
629629
630- while (!isa<LoadOp, DescriptorLoadOp >(op)) {
630+ while (!isa<LoadOp, DescriptorLoadLikeOpInterface >(op)) {
631631 if (auto reshape = dyn_cast<ReshapeOp>(op)) {
632632 op = reshape.getSrc ().getDefiningOp ();
633633 loadConsumer = reshape;
0 commit comments