
Commit 3322ca9

Add torch compile unit test to helion (#782)
1 parent 46876f2 · commit 3322ca9

File tree

2 files changed: 41 additions & 3 deletions


helion/runtime/kernel.py

Lines changed: 1 addition & 3 deletions
@@ -385,9 +385,7 @@ def configs(self) -> list[Config]:
 
     def format_kernel_decorator(self, config: Config, settings: Settings) -> str:
         """Return the @helion.kernel decorator snippet capturing configs and settings that influence Triton code generation."""
-        return (
-            f"@helion.kernel(config={config!r}, static_shapes={settings.static_shapes})"
-        )
+        return f"@helion.kernel(config={config.__repr__()}, static_shapes={settings.static_shapes})"
 
     def to_triton_code(
         self, config: ConfigLike | None = None, emit_repro_caller: bool = False
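
Editor's note (a minimal sketch, not part of the commit): in an f-string, the conversions config!r and config.__repr__() both render repr(config) for ordinary classes, so this one-line rewrite emits the same decorator snippet as before, just without the wrapping parentheses. The Cfg class below is a hypothetical stand-in for helion.Config, used only to make the equivalence checkable.

# Illustrative only: Cfg is a hypothetical stand-in for helion.Config.
class Cfg:
    def __repr__(self) -> str:
        return "Config(block_sizes=[1, 2])"

cfg = Cfg()
# Both spellings produce the same text inside the f-string.
assert f"config={cfg!r}" == f"config={cfg.__repr__()}"
print(f"@helion.kernel(config={cfg!r}, static_shapes=True)")
# Prints: @helion.kernel(config=Config(block_sizes=[1, 2]), static_shapes=True)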

test/test_torch_compile.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import skipIfRefEager
+import helion.language as hl
+
+
+class TestTorchCompile(RefEagerTestBase, TestCase):
+    @skipIfRefEager("does not work with ref eager")
+    def test_add_kernel(self):
+        @helion.kernel(config=helion.Config(block_sizes=[1, 2]))
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            for tile in hl.tile(out.size()):
+                out[tile] = x[tile] + y[tile]
+            return out
+
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return add(x, y)
+
+        x = torch.randn(4, 8, device=DEVICE, dtype=torch.float16)
+        y = torch.randn(4, 8, device=DEVICE, dtype=torch.float16)
+
+        out = add(x, y)
+        compiled_add = torch.compile(f, fullgraph=True, backend="inductor")
+        compiled_out = compiled_add(x, y)
+
+        torch.testing.assert_close(out, x + y)
+        torch.testing.assert_close(compiled_out, x + y)
+
+
+if __name__ == "__main__":
+    unittest.main()
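
Editor's note (a minimal sketch, not part of the commit): the test relies on fullgraph=True, which makes torch.compile raise if the wrapped call cannot be captured without a graph break. The stripped-down example below shows the same wrapping pattern with a plain PyTorch add in place of the Helion kernel, so the torch.compile behavior can be tried without Helion or a GPU.

# Minimal sketch of the test's torch.compile pattern, without Helion.
import torch

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return add(x, y)

# fullgraph=True asserts the whole call is captured as one graph (no breaks).
compiled_f = torch.compile(f, fullgraph=True, backend="inductor")

x = torch.randn(4, 8)
y = torch.randn(4, 8)
torch.testing.assert_close(compiled_f(x, y), x + y)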
