Skip to content

Commit ce50eac

Browse files
authored
[Backend][NFC] FuncOpToLLVM: Move handleArgPtrDatatype to Utility.h (#9120)
Make handleArgPtrDatatype a utility function to avoid code duplication.
1 parent 1d5a827 commit ce50eac

4 files changed

Lines changed: 16 additions & 34 deletions

File tree

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,7 @@ void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs,
641641
triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
642642
ConversionPatternRewriter &rewriter,
643643
const TargetInfoBase &targetInfo);
644+
void handleArgPtrDatatype(triton::FuncOp funcOp, LLVM::LLVMFuncOp &llvmFuncOp);
644645
} // namespace mlir
645646

646647
#endif

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,6 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
6565
}
6666
}
6767

68-
static void handleArgPtrDatatype(triton::FuncOp funcOp,
69-
LLVM::LLVMFuncOp &llvmFuncOp) {
70-
71-
// The convertion from triton::PointerType to LLVM::LLVMPointerType losts
72-
// the pointee datatype. This function add the pointee datatype to arg
73-
// attribute.
74-
FunctionType fty = funcOp.getFunctionType();
75-
for (unsigned i = 0; i < fty.getNumInputs(); ++i) {
76-
auto argType = fty.getInput(i);
77-
if (auto argPtrType = dyn_cast<triton::PointerType>(argType)) {
78-
auto argDType = argPtrType.getPointeeType();
79-
llvmFuncOp.setArgAttr(i, "tt.pointee_type",
80-
mlir::TypeAttr::get(argDType));
81-
}
82-
}
83-
}
84-
8568
LogicalResult
8669
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
8770
ConversionPatternRewriter &rewriter) const override {

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,4 +1639,19 @@ triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
16391639
return amendedFuncOp;
16401640
}
16411641

1642+
void handleArgPtrDatatype(triton::FuncOp funcOp, LLVM::LLVMFuncOp &llvmFuncOp) {
1643+
// The convertion from triton::PointerType to LLVM::LLVMPointerType losts
1644+
// the pointee datatype information.
1645+
// This function add back the pointee datatype information to arg attribute.
1646+
FunctionType fty = funcOp.getFunctionType();
1647+
for (unsigned i = 0; i < fty.getNumInputs(); ++i) {
1648+
auto argType = fty.getInput(i);
1649+
if (auto argPtrType = dyn_cast<triton::PointerType>(argType)) {
1650+
auto argDType = argPtrType.getPointeeType();
1651+
llvmFuncOp.setArgAttr(i, "tt.pointee_type",
1652+
mlir::TypeAttr::get(argDType));
1653+
}
1654+
}
1655+
}
1656+
16421657
} // namespace mlir

third_party/amd/lib/TritonAMDGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,6 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
1010
const TargetInfoBase &targetInfo, PatternBenefit benefit)
1111
: ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {}
1212

13-
static void handleArgPtrDatatype(triton::FuncOp funcOp,
14-
LLVM::LLVMFuncOp &llvmFuncOp) {
15-
16-
// The convertion from triton::PointerType to LLVM::LLVMPointerType losts
17-
// the pointee datatype information.
18-
// This function add back the pointee datatype information to arg attribute.
19-
FunctionType fty = funcOp.getFunctionType();
20-
for (unsigned i = 0; i < fty.getNumInputs(); ++i) {
21-
auto argType = fty.getInput(i);
22-
if (auto argPtrType = dyn_cast<triton::PointerType>(argType)) {
23-
auto argDType = argPtrType.getPointeeType();
24-
llvmFuncOp.setArgAttr(i, "tt.pointee_type",
25-
mlir::TypeAttr::get(argDType));
26-
}
27-
}
28-
}
29-
3013
LogicalResult
3114
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
3215
ConversionPatternRewriter &rewriter) const override {

0 commit comments

Comments
 (0)