Skip to content

Commit 105ecb0

Browse files
i3wanna2Galaxy1458zhzhcookieikushareyoungyoung01
authored
[HINT] add hint for dsl and ast (flagos-ai#13)
rename hint,add a simple test (only DSL code, excluding compilation and MLIR tests, which are expected to be done later), fix a load op issue --------- Co-authored-by: i3wanna2 <2535184404@qq.com> Co-authored-by: zhengyang <zhengyang@baai.ac.cn> Co-authored-by: ikushare <2697533527@qq.com> Co-authored-by: yangcm <3259543738@qq.com>
1 parent b245a06 commit 105ecb0

File tree

11 files changed

+256
-51
lines changed

11 files changed

+256
-51
lines changed

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_ATTR_DEFS
33

44
include "mlir/IR/EnumAttr.td"
5+
include "mlir/IR/AttrTypeBase.td"
56

67
// Attributes for LoadOp and StoreOp
78
def TT_CacheModifierAttr : I32EnumAttr<

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
1414
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
1515
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
1616
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
17+
include "mlir/IR/BuiltinAttributes.td"
1718

1819

1920
//
@@ -248,13 +249,33 @@ def TT_LoadOp : TT_Op<"load", [
248249
OptionalAttr<TT_PaddingOptionAttr>:$padding,
249250
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
250251
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
251-
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
252+
DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
253+
// TODO: now flagtree_hints is string, default value of an empty string (""), needed redesign
254+
DefaultValuedAttr<StrAttr, "\"\"">:$flagtree_hints
252255
);
253256

254257
let results = (outs TT_Type:$result);
255258

256259
let builders = [
257260
// A tensor of pointers or a pointer to a scalar
261+
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
262+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
263+
// A tensor pointer with boundary check and padding
264+
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
265+
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
266+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
267+
// A tensor of pointers or a pointer to a scalar with mask
268+
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
269+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
270+
// A tensor of pointers or a pointer to a scalar with mask and other
271+
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
272+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
273+
// A utility function to build the operation with all attributes
274+
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
275+
"ArrayRef<int32_t>":$boundaryCheck,
276+
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
277+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
278+
// A tensor of pointers or a pointer to a scalar
258279
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
259280
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
260281
// A tensor pointer with boundary check and padding

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
4545
cache, evict, isVolatile);
4646
}
4747

48+
// implementatio with flagtree_hints
49+
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
50+
CacheModifier cache, EvictionPolicy evict, bool isVolatile,
51+
mlir::StringAttr flagtree_hints) {
52+
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{},
53+
/*boundaryCheck=*/ArrayRef<int32_t>{}, /*padding=*/std::nullopt,
54+
cache, evict, isVolatile, flagtree_hints);
55+
}
56+
4857
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
4958
ArrayRef<int32_t> boundaryCheck,
5059
std::optional<PaddingOption> padding, CacheModifier cache,
@@ -53,6 +62,16 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
5362
padding, cache, evict, isVolatile);
5463
}
5564

65+
// implementatio with flagtree_hints
66+
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
67+
ArrayRef<int32_t> boundaryCheck,
68+
std::optional<PaddingOption> padding, CacheModifier cache,
69+
EvictionPolicy evict, bool isVolatile,
70+
mlir::StringAttr flagtree_hints) {
71+
LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck,
72+
padding, cache, evict, isVolatile, flagtree_hints);
73+
}
74+
5675
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
5776
Value mask, CacheModifier cache, EvictionPolicy evict,
5877
bool isVolatile) {
@@ -61,6 +80,16 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
6180
/*padding=*/std::nullopt, cache, evict, isVolatile);
6281
}
6382

83+
// implementatio with flagtree_hints
84+
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
85+
Value mask, CacheModifier cache, EvictionPolicy evict,
86+
bool isVolatile, mlir::StringAttr flagtree_hints) {
87+
LoadOp::build(builder, state, ptr, mask, /*other=*/{},
88+
/*boundaryCheck=*/ArrayRef<int32_t>{},
89+
/*padding=*/std::nullopt, cache, evict, isVolatile,
90+
flagtree_hints);
91+
}
92+
6493
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
6594
Value mask, Value other, CacheModifier cache,
6695
EvictionPolicy evict, bool isVolatile) {
@@ -69,6 +98,17 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
6998
/*padding=*/std::nullopt, cache, evict, isVolatile);
7099
}
71100

101+
// implementatio with flagtree_hints
102+
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
103+
Value mask, Value other, CacheModifier cache,
104+
EvictionPolicy evict, bool isVolatile,
105+
mlir::StringAttr flagtree_hints) {
106+
LoadOp::build(builder, state, ptr, mask, other,
107+
/*boundaryCheck=*/ArrayRef<int32_t>{},
108+
/*padding=*/std::nullopt, cache, evict, isVolatile,
109+
flagtree_hints);
110+
}
111+
72112
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
73113
Value mask, Value other, ArrayRef<int32_t> boundaryCheck,
74114
std::optional<PaddingOption> padding, CacheModifier cache,
@@ -82,6 +122,21 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
82122
evict, isVolatile);
83123
}
84124

125+
// implementatio with flagtree_hints
126+
void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
127+
Value mask, Value other, ArrayRef<int32_t> boundaryCheck,
128+
std::optional<PaddingOption> padding, CacheModifier cache,
129+
EvictionPolicy evict, bool isVolatile,
130+
mlir::StringAttr flagtree_hints) {
131+
auto paddingAttr =
132+
padding.has_value()
133+
? PaddingOptionAttr::get(builder.getContext(), padding.value())
134+
: PaddingOptionAttr();
135+
LoadOp::build(builder, state, ptr, mask, other,
136+
builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache,
137+
evict, isVolatile, flagtree_hints);
138+
}
139+
85140
// load(ptr, splat(1), ...) -> load(ptr, ...)
86141
// load(ptr, splat(0), other, ...) -> other
87142
struct CanonicalizeMaskedLoadPattern : public OpRewritePattern<LoadOp> {

lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,8 @@ class RewriteTensorPointerPass
312312
if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
313313
auto newResult = builder.create<triton::LoadOp>(
314314
loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(),
315-
loadOp.getEvict(), loadOp.getIsVolatile());
315+
loadOp.getEvict(), loadOp.getIsVolatile(),
316+
loadOp.getFlagtreeHintsAttr());
316317
op->getResult(0).replaceAllUsesWith(newResult);
317318
if (op->getAttr("async_task_id"))
318319
newResult->setAttr("async_task_id", op->getAttr("async_task_id"));

python/src/ir.cc

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,9 +1360,14 @@ void init_triton_ir(py::module &&m) {
13601360
// Input/Output
13611361
.def("create_load",
13621362
[](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier,
1363-
EvictionPolicy evictionPolicy, bool isVolatile) -> Value {
1363+
EvictionPolicy evictionPolicy, bool isVolatile,
1364+
std::optional<std::string> flagtree_hints) -> Value {
1365+
auto myHintsAttr =
1366+
flagtree_hints
1367+
? mlir::StringAttr::get(self.getContext(), *flagtree_hints)
1368+
: mlir::StringAttr::get(self.getContext(), "");
13641369
return self.create<LoadOp>(ptrs, cacheModifier, evictionPolicy,
1365-
isVolatile);
1370+
isVolatile, myHintsAttr);
13661371
})
13671372
.def("create_store",
13681373
[](TritonOpBuilder &self, Value &ptrs, Value &value,
@@ -1375,10 +1380,16 @@ void init_triton_ir(py::module &&m) {
13751380
std::vector<int32_t> &boundaryCheck,
13761381
std::optional<PaddingOption> paddingOption,
13771382
CacheModifier cacheModifier, EvictionPolicy evictionPolicy,
1378-
bool isVolatile) -> Value {
1383+
bool isVolatile,
1384+
std::optional<std::string> flagtree_hints) -> Value {
1385+
auto myHintsAttr =
1386+
flagtree_hints
1387+
? mlir::StringAttr::get(self.getContext(), *flagtree_hints)
1388+
: mlir::StringAttr::get(self.getContext(), "");
1389+
13791390
return self.create<LoadOp>(ptr, boundaryCheck, paddingOption,
13801391
cacheModifier, evictionPolicy,
1381-
isVolatile);
1392+
isVolatile, myHintsAttr);
13821393
})
13831394
.def("create_tensor_pointer_store",
13841395
[](TritonOpBuilder &self, Value &ptr, Value &val,
@@ -1390,10 +1401,15 @@ void init_triton_ir(py::module &&m) {
13901401
.def("create_masked_load",
13911402
[](TritonOpBuilder &self, Value &ptrs, Value &mask,
13921403
std::optional<Value> &other, CacheModifier cacheModifier,
1393-
EvictionPolicy evictionPolicy, bool isVolatile) -> Value {
1404+
EvictionPolicy evictionPolicy, bool isVolatile,
1405+
std::optional<std::string> flagtree_hints) -> Value {
1406+
auto myHintsAttr =
1407+
flagtree_hints
1408+
? mlir::StringAttr::get(self.getContext(), *flagtree_hints)
1409+
: mlir::StringAttr::get(self.getContext(), "");
13941410
return self.create<LoadOp>(ptrs, mask, other.value_or(Value()),
13951411
cacheModifier, evictionPolicy,
1396-
isVolatile);
1412+
isVolatile, myHintsAttr);
13971413
})
13981414
.def("create_masked_store",
13991415
[](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask,

python/triton/compiler/code_generator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,23 +1229,45 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
12291229
return next(unflatten_ir_values(handles, [callee_ret_type]))
12301230

12311231
def visit_Call(self, node):
1232+
# 1. Get the called function object
12321233
fn = _unwrap_if_constexpr(self.visit(node.func))
1234+
1235+
# 2. Check if it's a statically implemented function
12331236
static_implementation = self.statically_implemented_functions.get(fn)
12341237
if static_implementation is not None:
12351238
return static_implementation(self, node)
12361239

1240+
# 3. Process keyword and positional arguments
12371241
kws = dict(self.visit(keyword) for keyword in node.keywords)
12381242
args = [self.visit(arg) for arg in node.args]
12391243
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1244+
1245+
# 4. Get current line number and hints
1246+
line_num = node.lineno
1247+
function_def = self.jit_fn.parse()
1248+
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
1249+
flagtree_hints = line_flagtree_hints.get(line_num)
1250+
1251+
# 5. Handle JIT function calls
12401252
if isinstance(fn, JITFunction):
12411253
_check_fn_args(node, fn, args)
12421254
return self.call_JitFunction(fn, args, kws)
1255+
1256+
# 6. Handle built-in functions or calls with special context
12431257
if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
12441258
extra_kwargs = {"_builder": self.builder}
12451259
sig = inspect.signature(fn)
12461260
if '_generator' in sig.parameters:
12471261
extra_kwargs['_generator'] = self
12481262
try:
1263+
# Special handling for tl.load with hints
1264+
if fn.__name__ == "load" and flagtree_hints is not None:
1265+
print(f"tl.load at line {line_num} has annotation {flagtree_hints}")
1266+
if 'flagtree_hints' not in kws:
1267+
kws['flagtree_hints'] = ""
1268+
if flagtree_hints not in kws['flagtree_hints']:
1269+
kws['flagtree_hints'] = flagtree_hints
1270+
12491271
ret = fn(*args, **extra_kwargs, **kws)
12501272
# builtin functions return plain tuples for readability
12511273
if isinstance(ret, tuple):
@@ -1260,6 +1282,7 @@ def visit_Call(self, node):
12601282
# be in core.py.
12611283
raise CompilationError(self.jit_fn.src, node, None) from e
12621284

1285+
# 7. Handle calls from built-in namespace
12631286
if fn in self.builtin_namespace.values():
12641287
args = map(_unwrap_if_constexpr, args)
12651288
ret = fn(*args, **kws)

python/triton/language/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,7 +1857,7 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
18571857

18581858
@builtin
18591859
def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
1860-
volatile=False, _builder=None):
1860+
volatile=False, flagtree_hints=None, _builder=None):
18611861
"""
18621862
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
18631863
@@ -1911,8 +1911,9 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
19111911
cache_modifier = _constexpr_to_value(cache_modifier)
19121912
eviction_policy = _constexpr_to_value(eviction_policy)
19131913
volatile = _constexpr_to_value(volatile)
1914+
flagtree_hints = _constexpr_to_value(flagtree_hints)
19141915
return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
1915-
volatile, _builder)
1916+
volatile, flagtree_hints, _builder)
19161917

19171918

19181919
@builtin

python/triton/language/semantic.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,7 +1047,8 @@ def _canonicalize_boundary_check(boundary_check, block_shape):
10471047
return ()
10481048

10491049

1050-
def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
1050+
def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, flagtree_hints,
1051+
builder):
10511052
# Load by a block pointer: `pointer_type<block_type<>>`
10521053
# Block pointer can not have `mask` and `other` arguments
10531054
if mask is not None or other is not None:
@@ -1066,10 +1067,11 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti
10661067

10671068
# Build IR
10681069
return tl.tensor(
1069-
builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty)
1070+
builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile,
1071+
flagtree_hints), dst_ty)
10701072

10711073

1072-
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
1074+
def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, flagtree_hints, builder):
10731075
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
10741076
if not ptr.type.scalar.is_ptr():
10751077
raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
@@ -1121,18 +1123,18 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
11211123

11221124
# Build IR
11231125
if mask is None:
1124-
ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
1126+
ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile, flagtree_hints), dst_ty)
11251127
else:
11261128
ret = tl.tensor(
11271129
builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
1128-
is_volatile), dst_ty)
1130+
is_volatile, flagtree_hints), dst_ty)
11291131
if is_bool:
11301132
ret = cast(ret, tl.int1, builder)
11311133
return ret
11321134

11331135

11341136
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple,
1135-
padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool,
1137+
padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, flagtree_hints: str,
11361138
builder: ir.builder) -> tl.tensor:
11371139
# Cache, eviction and padding options
11381140
cache = _str_to_load_cache_modifier(cache_modifier)
@@ -1141,10 +1143,12 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor],
11411143

11421144
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
11431145
# Load by a block pointer: `pointer_type<block_type<>>`
1144-
return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
1146+
return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile,
1147+
flagtree_hints, builder)
11451148
else:
11461149
# Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1147-
return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
1150+
return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, flagtree_hints,
1151+
builder)
11481152

11491153

11501154
def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder):

python/triton/runtime/jit.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from ..runtime.driver import driver
1313
from types import ModuleType
1414
from .._utils import find_paths_if, get_iterable_path
15+
import tokenize
16+
from io import StringIO
1517

1618
TRITON_MODULE = __name__[:-len(".runtime.jit")]
1719

@@ -703,10 +705,26 @@ def preload(self, specialization_data):
703705
# the user might want to monkey-patch self.src dynamically.
704706
# Our unit tests do this, for example.
705707
def parse(self):
708+
# Maps line numbers to comment hints
709+
line_flagtree_hints = {}
710+
code_str = self.src
711+
g = tokenize.generate_tokens(StringIO(code_str).readline)
712+
for tok_type, tok_text, start, end, _ in g:
713+
if tok_type == tokenize.COMMENT:
714+
comment = tok_text.replace(" ", "").strip()
715+
if comment.startswith('#@hint:'):
716+
flagtree_hints = comment[len('#@hint:'):].strip()
717+
# Record the line number of the comment
718+
line_num = start[0]
719+
line_flagtree_hints[line_num] = flagtree_hints
720+
706721
tree = ast.parse(self.src)
707722
assert isinstance(tree, ast.Module)
708723
assert len(tree.body) == 1
709724
assert isinstance(tree.body[0], ast.FunctionDef)
725+
726+
# Attach the line number to comment mapping to the function definition node
727+
tree.body[0].line_flagtree_hints = line_flagtree_hints
710728
return tree
711729

712730
def __call__(self, *args, **kwargs):

0 commit comments

Comments
 (0)