
Commit 2e1ea33

Fix issue with fp64 constants (#506)
1 parent 2cae9ed commit 2e1ea33
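
This change stops Python float constants from forcing float64 in generated Triton code: Inductor's default expression printer renders a SymPy Float as a 0-D value via tl.full([], <val>, tl.float64), so any Helion kernel expression containing a Python float was silently upcast. The commit installs a Helion-specific printer that emits the bare literal and lets ordinary dtype promotion pick the precision, with a regression test for https://github.com/pytorch/helion/issues/493.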

File tree

4 files changed: +62 -4 lines changed

helion/_compiler/device_function.py

Lines changed: 20 additions & 1 deletion
@@ -14,7 +14,7 @@
 
 import sympy
 import torch
-from torch._inductor.codegen.triton import texpr
+from torch._inductor.codegen.triton import TritonPrinter
 from torch.fx.graph import _Namespace
 
 from .._compat import get_tensor_descriptor_fn_name

@@ -599,3 +599,22 @@ def current() -> DeviceFunction:
             return tls.functions[-1]
         except (AttributeError, IndexError):
             raise NoCurrentFunction from None
+
+
+class HelionTritonPrinter(TritonPrinter):
+    """Custom Triton printer that avoids wrapping float literals in tl.full().
+
+    Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
+    via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
+    literal, letting downstream type promotion and casts handle dtype.
+    """
+
+    def _print_Float(self, expr: sympy.Expr) -> str:
+        return str(expr)
+
+    def _print_ToFloat(self, expr: sympy.Expr) -> str:
+        return f"{expr} + 0.0"
+
+
+def texpr(expr: sympy.Expr) -> str:
+    return HelionTritonPrinter().doprint(expr)
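
To see the effect of the new printer in isolation, a hypothetical snippet (not part of this commit; it assumes torch and a post-commit helion checkout are importable) that prints the same SymPy constant through both printers:

import sympy
from torch._inductor.codegen.triton import TritonPrinter
from helion._compiler.device_function import HelionTritonPrinter

expr = sympy.Float(0.001953125)

# Inductor's default printer wraps the constant as a 0-D fp64 value,
# along the lines of: tl.full([], 0.001953125, tl.float64)
print(TritonPrinter().doprint(expr))

# The Helion override prints the bare literal, so the surrounding
# expression's dtype promotion decides the final precision.
print(HelionTritonPrinter().doprint(expr))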

test/test_broadcasting.expected

Lines changed: 21 additions & 0 deletions
@@ -246,3 +246,24 @@ def fn(a, b, *, _launcher=_default_launcher):
     _BLOCK_SIZE_1 = 16
     _launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestBroadcasting.test_python_float_promotion)
+from __future__ import annotations
+
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(a, a_size_0, a_stride_0, beta, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    b = tl.load(tl.make_block_ptr(a, [a_size_0], [a_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
+    sub = 1.0 + -1 * beta
+    v_0 = b * sub
+    tl.store(tl.make_block_ptr(a, [a_size_0], [a_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), v_0, boundary_check=[0])
+
+def fn(a, beta, *, _launcher=_default_launcher):
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, a.size(0), a.stride(0), beta, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return a
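
Note that in this expected journal the Python float appears as the plain scalar expression sub = 1.0 + -1 * beta rather than an fp64 tl.full constant, so v_0 keeps the input's float32 dtype.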

test/test_broadcasting.py

Lines changed: 20 additions & 0 deletions
@@ -106,6 +106,26 @@ def fn(a, b):
         torch.testing.assert_close(out, sum(args))
         self.assertExpectedJournal(code)
 
+    def test_python_float_promotion(self):
+        # Repro for https://github.com/pytorch/helion/issues/493
+        # Python floats should follow PyTorch type promotion (no unintended fp64 upcast)
+        @helion.kernel(config={"block_size": 16, "indexing": "block_ptr"})
+        def fn(a, beta):
+            for tile0 in hl.tile(a.shape[0]):
+                b = a[tile0]
+                a[tile0] = (1 - beta) * b
+            return a
+
+        a = torch.randn(1024, device=DEVICE)
+        beta = 1.5
+        args = (a, beta)
+
+        # Expected behavior matches PyTorch promotion rules on tensors
+        expected = (1 - beta) * a
+        code, out = code_and_output(fn, args)
+        torch.testing.assert_close(out, expected)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
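
For reference, the promotion rule the test relies on can be checked with plain PyTorch (a standalone sketch, independent of Helion): Python scalars are "weak" in type promotion and do not upcast a float32 tensor.

import torch

a = torch.randn(4, dtype=torch.float32)
beta = 1.5

# (1 - beta) is a Python float; PyTorch treats it as a weak scalar,
# so the product stays float32 instead of being upcast to float64.
result = (1 - beta) * a
assert result.dtype == torch.float32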

test/test_specialize.expected

Lines changed: 1 addition & 3 deletions
@@ -79,9 +79,7 @@ def _helion_fn(x, out, x_size_0, out_stride_0, out_stride_1, x_stride_0, x_strid
     mask_0 = indices_0 < x_size_0
     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
     mask_1 = indices_1 < 500
-    sym_float = tl.full([], 512.0, tl.float64)
-    truediv = tl.full([], 0.001953125, tl.float64)
-    acc = tl.full([_BLOCK_SIZE_0, 512], truediv, tl.float32)
+    acc = tl.full([_BLOCK_SIZE_0, 512], 0.001953125, tl.float32)
     acc2 = tl.full([512, 512], 1.0, tl.float32)
     _mask_to = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 512]), acc, 0)
     acc_1 = tl.dot(_mask_to, acc2, input_precision='tf32')
