Skip to content

Commit f97f66a

Browse files
authored
Revert "[language] Skip f16 to f32 promotion in max/min reductions" (#9921)
Reverting as it breaks backward compatibility and needs some time to update. Reverts #9903.
1 parent 8956d90 commit f97f66a

2 files changed

Lines changed: 3 additions & 30 deletions

File tree

python/test/unit/language/test_compile_only.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -220,26 +220,3 @@ def fp8_convert(src, dst):
220220
src = ASTSource(fn=fp8_convert, signature={"src": "*fp32", "dst": "*fp8e5"}, constexprs={})
221221
triton.compile(src, target=GPUTarget("cuda", 90, 32))
222222
triton.compile(src, target=GPUTarget("cuda", 80, 32))
223-
224-
225-
def test_f16_min_max_no_promotion():
226-
"""f16 should not get promoted to f32 in min/max reductions."""
227-
228-
@triton.jit
229-
def reduce_min(src, dst):
230-
idx = tl.arange(0, 64)
231-
x = tl.load(src + idx)
232-
tl.store(dst, tl.min(x, axis=0))
233-
234-
@triton.jit
235-
def reduce_max(src, dst):
236-
idx = tl.arange(0, 64)
237-
x = tl.load(src + idx)
238-
tl.store(dst, tl.max(x, axis=0))
239-
240-
targets = [GPUTarget("cuda", 90, 32), GPUTarget("hip", "gfx942", 64)]
241-
for target in targets:
242-
for kernel in [reduce_min, reduce_max]:
243-
f16 = triton.compile(ASTSource(fn=kernel, signature={"src": "*fp16", "dst": "*fp16"}, constexprs={}),
244-
target=target)
245-
assert "arith.extf" not in f16.asm["ttir"], "f16 should not get promoted to f32 in min/max reductions"

python/triton/language/standard.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
184184
else:
185185
if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
186186
if core.constexpr(input.dtype.is_floating()):
187-
# Do not promote f16 to f32 as it has native hardware support
188-
if not core.constexpr(input.dtype == core.float16):
189-
input = input.to(core.float32)
187+
input = input.to(core.float32)
190188
else:
191189
assert input.dtype.is_int(), "Expecting input to be integer type"
192190
input = input.to(core.int32)
@@ -243,11 +241,9 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
243241
else:
244242
return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
245243
else:
246-
if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
244+
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
247245
if core.constexpr(input.dtype.is_floating()):
248-
# Do not promote f16 to f32 as it has native hardware support
249-
if not core.constexpr(input.dtype == core.float16):
250-
input = input.to(core.float32)
246+
input = input.to(core.float32)
251247
else:
252248
assert input.dtype.is_int(), "Expecting input to be integer type"
253249
input = input.to(core.int32)

0 commit comments

Comments
 (0)