Skip to content

Commit 170cd7e

Browse files
sjw36CRobeck
authored andcommitted
[IR] Update names of decomposed tensor descriptor args (triton-lang#9587)
Tensor descriptor function arguments are decomposed into tt.ptr, shape, stride, and some flags. This pr updates each new parameter with a meaningful name, aligned with codegen'd tensordesc naming.
1 parent a81b247 commit 170cd7e

4 files changed

Lines changed: 139 additions & 7 deletions

File tree

include/triton/Dialect/Triton/Transforms/FunctionTypeConversion.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,65 @@
11
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
22
#define TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
3+
#include "mlir/Interfaces/FunctionInterfaces.h"
34
#include "mlir/Transforms/DialectConversion.h"
45

56
namespace mlir::triton {
67

8+
/**
9+
* @brief Rename the aggregated function arguments that were generated by the
10+
* type converter.
11+
*
12+
* @note The callback should return a std::optional<LogicalResult>. If the
13+
* callback returns a std::nullopt, the argument is not renamed and the
14+
* conversion continues. If the callback returns a LogicalResult, the out_suffix
15+
* list is populated with the suffixes for each aggregated argument.
16+
*
17+
* @param delimiter The delimiter to use between the base name and the suffixes.
18+
*/
19+
class FuncArgRenamer {
20+
public:
21+
FuncArgRenamer(const char *delimiter = ".") : delimiter(delimiter) {}
22+
template <typename FnT, typename T = typename llvm::function_traits<
23+
std::decay_t<FnT>>::template arg_t<0>>
24+
void addRenamer(FnT &&callback) {
25+
renamers.emplace_back(wrapCallback<T>(std::forward<FnT>(callback)));
26+
}
27+
LogicalResult apply(Type type, FunctionOpInterface funcOp, int index,
28+
TypeConverter::SignatureConversion &conversion) const;
29+
30+
private:
31+
using RenamerCallbackFn = std::function<std::optional<LogicalResult>(
32+
Type, llvm::SmallVectorImpl<std::string> &)>;
33+
34+
/**
35+
* @brief Wraps a callback of form `std::optional<LogicalResult>(T,
36+
* llvm::SmallVectorImpl<std::string> &)` into a RenamerCallbackFn.
37+
*
38+
* @tparam T The type of the argument.
39+
* @tparam FnT The type of the callback.
40+
* @param callback The callback to wrap.
41+
* @return A RenamerCallbackFn.
42+
*/
43+
template <typename T, typename FnT>
44+
std::enable_if_t<
45+
std::is_invocable_v<FnT, T, llvm::SmallVectorImpl<std::string> &> &&
46+
std::is_base_of_v<Type, T>,
47+
RenamerCallbackFn>
48+
wrapCallback(FnT &&callback) const {
49+
return [callback = std::forward<FnT>(callback)](
50+
Type type, llvm::SmallVectorImpl<std::string> &out_suffix)
51+
-> std::optional<LogicalResult> {
52+
if (auto t = dyn_cast<T>(type)) {
53+
return callback(t, out_suffix);
54+
}
55+
return std::nullopt;
56+
};
57+
}
58+
59+
llvm::SmallVector<RenamerCallbackFn> renamers;
60+
const char *delimiter;
61+
};
62+
763
/**
864
* @brief Provides helper patterns for converting triton function operations
965
* using a type converter.
@@ -12,6 +68,7 @@ namespace mlir::triton {
1268
* tt.call and tt.return.
1369
*/
1470
void populateFunctionTypeConversions(const TypeConverter &converter,
71+
const FuncArgRenamer &renamer,
1572
RewritePatternSet &patterns);
1673

1774
} // namespace mlir::triton

lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,42 @@
1111

1212
namespace mlir::triton {
1313

14+
LogicalResult
15+
FuncArgRenamer::apply(Type type, FunctionOpInterface funcOp, int index,
16+
TypeConverter::SignatureConversion &conversion) const {
17+
auto mapping = conversion.getInputMapping(index);
18+
if (!mapping)
19+
return success();
20+
21+
for (auto &renamer : llvm::reverse(renamers)) {
22+
llvm::SmallVector<std::string, 8> out_suffix;
23+
if (std::optional<LogicalResult> result = renamer(type, out_suffix)) {
24+
if (failed(*result)) {
25+
return failure();
26+
}
27+
int newIndex = mapping->inputNo;
28+
auto loc = funcOp.getArgument(newIndex).getLoc();
29+
std::string baseName;
30+
if (isa<NameLoc>(loc)) {
31+
baseName = cast<NameLoc>(loc).getName().getValue();
32+
} else {
33+
baseName = "arg_" + std::to_string(index);
34+
}
35+
assert(out_suffix.size() == mapping->size);
36+
for (auto [i, suffix] : llvm::enumerate(out_suffix)) {
37+
if (suffix.empty())
38+
continue;
39+
auto newLoc = NameLoc::get(
40+
StringAttr::get(funcOp.getContext(), baseName + delimiter + suffix),
41+
loc);
42+
funcOp.getArgument(newIndex + i).setLoc(newLoc);
43+
}
44+
return success(); // early return
45+
}
46+
}
47+
return success();
48+
}
49+
1450
namespace {
1551

1652
SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
@@ -102,6 +138,7 @@ convertFuncOpAttrs(FunctionOpInterface funcOp,
102138

103139
LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
104140
const TypeConverter &typeConverter,
141+
const FuncArgRenamer &renamer,
105142
ConversionPatternRewriter &rewriter) {
106143
FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
107144
if (!type)
@@ -129,6 +166,13 @@ LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
129166
}
130167
});
131168

169+
// Apply the renamer to the function signature.
170+
for (auto [i, input] : llvm::enumerate(type.getInputs())) {
171+
if (failed(renamer.apply(input, funcOp, i, result))) {
172+
return failure();
173+
}
174+
}
175+
132176
return success();
133177
}
134178

@@ -139,24 +183,30 @@ struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
139183
FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
140184
MLIRContext *ctx,
141185
const TypeConverter &converter,
186+
const FuncArgRenamer &renamer,
142187
PatternBenefit benefit = 1)
143-
: ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
188+
: ConversionPattern(converter, functionLikeOpName, benefit, ctx),
189+
renamer(renamer) {}
144190

145191
LogicalResult
146192
matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
147193
ConversionPatternRewriter &rewriter) const override {
148194
FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
149-
return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
195+
return convertFuncOpTypes(funcOp, *typeConverter, renamer, rewriter);
150196
}
197+
198+
private:
199+
const FuncArgRenamer &renamer;
151200
};
152201

153202
} // namespace
154203

155204
void populateFunctionTypeConversions(const TypeConverter &converter,
205+
const FuncArgRenamer &renamer,
156206
RewritePatternSet &patterns) {
157207
auto context = patterns.getContext();
158208
patterns.add<FunctionOpInterfaceSignatureConversion>(
159-
triton::FuncOp::getOperationName(), context, converter);
209+
triton::FuncOp::getOperationName(), context, converter, renamer);
160210
patterns.add<CallOpConversion, ReturnOpConversion>(converter, context);
161211
}
162212

lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,28 @@ class TritonRewriteTensorDescriptorToPointerPass
580580
return mlir::success();
581581
});
582582

583+
FuncArgRenamer renamer(".");
584+
renamer.addRenamer([](mlir::triton::TensorDescType type,
585+
llvm::SmallVectorImpl<std::string> &out_suffix) {
586+
auto tensorType = type.getSignlessBlockType();
587+
int dims = tensorType.getRank();
588+
out_suffix.push_back("");
589+
for (int i = 0; i < dims; i++) {
590+
out_suffix.push_back("shape." + std::to_string(i));
591+
}
592+
for (int i = 0; i < dims; i++) {
593+
out_suffix.push_back("stride." + std::to_string(i));
594+
}
595+
out_suffix.push_back("padding");
596+
out_suffix.push_back("roundF32ToTF32");
597+
return success();
598+
});
599+
583600
mlir::RewritePatternSet patterns(op->getContext());
584601

585602
// Populate conversion patterns to handle loops, function calls, and arith
586603
// ops.
587-
triton::populateFunctionTypeConversions(converter, patterns);
604+
triton::populateFunctionTypeConversions(converter, renamer, patterns);
588605
mlir::scf::populateSCFStructuralTypeConversions(converter, patterns);
589606
triton::populateArithTypeConversions(converter, patterns);
590607

test/Triton/rewrite-tensor-descriptor-to-pointer.mlir

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s --triton-rewrite-tensor-descriptor-to-pointer --canonicalize --cse --split-input-file | FileCheck %s --implicit-check-not \!tt.tensordesc
1+
// RUN: triton-opt %s --triton-rewrite-tensor-descriptor-to-pointer --canonicalize --cse --mlir-print-debuginfo --split-input-file | FileCheck %s --implicit-check-not \!tt.tensordesc
22

33
module {
44
tt.func public @load(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) -> (tensor<128x128xf32>) {
@@ -109,8 +109,9 @@ module {
109109

110110
// -----
111111

112+
#loc2 = loc("rewrite-tensor-descriptor-to-pointer.mlir":147:28)
112113
module {
113-
tt.func public @callee(%tensordesc: !tt.tensordesc<tensor<128x128xf32>>) -> !tt.tensordesc<tensor<128x128xf32>> {
114+
tt.func public @callee(%tensordesc: !tt.tensordesc<tensor<128x128xf32>> loc("tensordesc"(#loc2))) -> !tt.tensordesc<tensor<128x128xf32>> {
114115
tt.return %tensordesc : !tt.tensordesc<tensor<128x128xf32>>
115116
}
116117

@@ -126,12 +127,19 @@ module {
126127

127128
// CHECK-LABEL: @callee
128129
// CHECK-SAME: %[[PTR:[^:]*]]
130+
// CHECK-SAME: loc("tensordesc"(#loc{{[^,]*}}))
129131
// CHECK-SAME: %[[SHAPE0:[^:]*]]
132+
// CHECK-SAME: loc("tensordesc.shape.0"(#loc{{[^,]*}}))
130133
// CHECK-SAME: %[[SHAPE1:[^:]*]]
134+
// CHECK-SAME: loc("tensordesc.shape.1"(#loc{{[^,]*}}))
131135
// CHECK-SAME: %[[STRIDE0:[^:]*]]
136+
// CHECK-SAME: loc("tensordesc.stride.0"(#loc{{[^,]*}}))
132137
// CHECK-SAME: %[[STRIDE1:[^:]*]]
138+
// CHECK-SAME: loc("tensordesc.stride.1"(#loc{{[^,]*}}))
133139
// CHECK-SAME: %[[PAD:[^:]*]]
140+
// CHECK-SAME: loc("tensordesc.padding"(#loc{{[^,]*}}))
134141
// CHECK-SAME: %[[ROUND:[^:]*]]
142+
// CHECK-SAME: loc("tensordesc.roundF32ToTF32"(#loc{{[^,]*}}))
135143
// CHECK-NEXT: tt.return %[[PTR]], %[[SHAPE0]], %[[SHAPE1]], %[[STRIDE0]], %[[STRIDE1]], %[[PAD]], %[[ROUND]]
136144

137145
// CHECK-LABEL: @caller
@@ -150,4 +158,4 @@ module {
150158
}
151159

152160
// CHECK-LABEL: @arg_attr
153-
// CHECK-SAME: %arg7: i32 {tt.divisibility = 16 : i32}) {
161+
// CHECK-SAME: %arg7: i32 {tt.divisibility = 16 : i32} loc({{.*}})) {

0 commit comments

Comments
 (0)