@@ -52,7 +52,7 @@ Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
5252// according to freeVarMasks. The predicate may be null to indicate no
5353// predication is required.
5454Value emitRedundantThreadPredicate (
55- ModuleOp moduleOp, const llvm::MapVector<StringAttr, int32_t > &freeVarMasks,
55+ const llvm::MapVector<StringAttr, int32_t > &freeVarMasks,
5656 ConversionPatternRewriter &rewriter, Location loc,
5757 const NVIDIA::TargetInfo &targetInfo) {
5858 auto b = TritonLLVMOpBuilder (loc, rewriter);
@@ -452,14 +452,13 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
452452 << mask << " \n " ;
453453 }
454454
455- auto moduleOp = op->getParentOfType <ModuleOp>();
456455 const size_t dtsize =
457456 std::max<int >(1 , valueElemTy.getIntOrFloatBitWidth () / 8 );
458457 const size_t valueElemNBits = dtsize * 8 ;
459458
460459 auto freeVarMasks = getFreeVariableMasks (ptr.getType ());
461- Value threadPred = emitRedundantThreadPredicate (moduleOp, freeVarMasks,
462- rewriter, loc, targetInfo);
460+ Value threadPred =
461+ emitRedundantThreadPredicate (freeVarMasks, rewriter, loc, targetInfo);
463462 uint32_t regMask = freeVarMasks[str_attr (" reg" )];
464463
465464 const int numVecs = elemsPerThread / vec;
@@ -615,8 +614,8 @@ struct AtomicCASOpConversion
615614 << " elemsPerThread = " << elemsPerThread << " \n " ;
616615
617616 auto freeVarMasks = getFreeVariableMasks (op.getPtr ().getType ());
618- Value threadPred = emitRedundantThreadPredicate (moduleOp, freeVarMasks,
619- rewriter, loc, targetInfo);
617+ Value threadPred =
618+ emitRedundantThreadPredicate (freeVarMasks, rewriter, loc, targetInfo);
620619 uint32_t regMask = freeVarMasks[str_attr (" reg" )];
621620
622621 auto vecTy = vec_ty (valueElemTy, vec);
@@ -831,8 +830,8 @@ struct AtomicRMWOpConversion
831830 << " numElems = " << numElems;
832831
833832 auto freeVarMasks = getFreeVariableMasks (ptr.getType ());
834- Value threadPred = emitRedundantThreadPredicate (moduleOp, freeVarMasks,
835- rewriter, loc, targetInfo);
833+ Value threadPred =
834+ emitRedundantThreadPredicate (freeVarMasks, rewriter, loc, targetInfo);
836835 uint32_t regMask = freeVarMasks[str_attr (" reg" )];
837836
838837 auto packedTy = vec_ty (valueElemTy, packed);
@@ -1112,14 +1111,13 @@ struct AsyncCopyGlobalToLocalOpConversion
11121111 << vecBytes << " bytes" ;
11131112 }
11141113
1115- auto moduleOp = op->getParentOfType <ModuleOp>();
11161114 auto freeVarMasks = getFreeVariableMasks (srcTy);
11171115 // NOTE(@peterbell10): We load redundant data on different CTAs, so the data
11181116 // is available in each CTA's respective shared memory. Otherwise, we would
11191117 // need an additional broadcast step to copy the data between CTAs.
11201118 freeVarMasks[str_attr (" block" )] = 0 ;
1121- Value threadPred = emitRedundantThreadPredicate (moduleOp, freeVarMasks,
1122- rewriter, loc, targetInfo);
1119+ Value threadPred =
1120+ emitRedundantThreadPredicate (freeVarMasks, rewriter, loc, targetInfo);
11231121 uint32_t regMask = freeVarMasks[str_attr (" reg" )];
11241122
11251123 for (int i = 0 ; i < shmemAddrs.size (); i++) {
0 commit comments