Skip to content

Commit cee62bb

Browse files
authored
triton_kernels: use tl.clamp whenever possible (#8728)
In many cases `tl.clamp` can be compiled to a single instruction rather than two instructions for `tl.minimum` + `tl.maximum`: - Output saturation instruction modifier to clamp output to [0, 1] - `min.xorsign.abs.f32` instructions on Hopper+ - `V_MED3_F32` instructions on AMD # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `expected to be bit-exact`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 3ebfe54 commit cee62bb

2 files changed

Lines changed: 4 additions & 5 deletions

File tree

python/triton_kernels/triton_kernels/numerics_details/flexpoint.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ def flex_to_float(x, scale_ptr):
115115

116116
@triton.jit
117117
def clip(x, limit):
118-
res = tl.minimum(x, limit)
119-
res = tl.maximum(-limit, res)
120-
return res
118+
return tl.clamp(x, -limit, limit)
121119

122120

123121
@triton.jit

python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
@triton.jit
77
def clip(x, limit, clip_lower: tl.constexpr):
8-
res = tl.minimum(x, limit)
98
if clip_lower:
10-
res = tl.maximum(-limit, res)
9+
res = tl.clamp(x, -limit, limit)
10+
else:
11+
res = tl.minimum(x, limit)
1112
return res
1213

1314

0 commit comments

Comments
 (0)