Skip to content

Commit 788795b

Browse files
authored
[TEST_DEBUG] Enable test device assert (#5395)
Comes from #2755. This PR is a temporary fix leveraging subprocess for catching the abort signal, enabling test_device_assert until pytorch/pytorch#142135 is resolved.
1 parent d478b30 commit 788795b

File tree

9 files changed

+81
-27
lines changed

9 files changed

+81
-27
lines changed

python/test/unit/test_debug.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import torch
33
import triton.language as tl
44
import triton
5+
import sys
6+
import subprocess
7+
import os
58

69

710
@pytest.mark.parametrize('cond', [True, False])
@@ -10,29 +13,35 @@
1013
@pytest.mark.parametrize('env_var', [True, False])
1114
@pytest.mark.parametrize('jit_flag', [True, False])
1215
@pytest.mark.forked
13-
def test_device_assert(monkeypatch, cond, mask, opt_flag, env_var, jit_flag, device):
14-
monkeypatch.setenv("TRITON_DEBUG", str(int(env_var)))
15-
triton.knobs.refresh_knobs()
16-
torch.zeros([1], dtype=torch.int32, device=device)
17-
18-
@triton.jit(debug=jit_flag)
19-
def _kernel(COND: tl.constexpr, MASK: tl.constexpr):
20-
tl.device_assert(COND, 'test', mask=MASK)
16+
def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device):
17+
"""Temporary subprocess solution due to:
18+
https://github.com/pytorch/pytorch/issues/142135"""
2119

2220
is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag)
2321

24-
kwargs = {}
25-
if opt_flag is not None:
26-
kwargs["debug"] = opt_flag
27-
28-
if not cond and is_debug and mask is not False:
29-
with pytest.raises(RuntimeError):
30-
_kernel[(1, )](cond, mask, **kwargs)
31-
getattr(torch, device).synchronize()
32-
return
33-
34-
_kernel[(1, )](cond, mask, **kwargs)
35-
getattr(torch, device).synchronize()
22+
should_fail = not cond and is_debug and mask is not False
23+
kernel_file = os.path.join(os.path.dirname(__file__), "test_debug_kernels.py")
24+
mask_str = "None" if mask is None else str(mask)
25+
opt_flag_str = "None" if opt_flag is None else str(opt_flag)
26+
27+
env = os.environ.copy()
28+
env["TRITON_DEBUG"] = str(int(env_var))
29+
30+
result = subprocess.run(
31+
[sys.executable, kernel_file, "device_assert",
32+
str(cond), mask_str, opt_flag_str,
33+
str(jit_flag), device], capture_output=True, text=True, env=env)
34+
35+
if should_fail:
36+
if device == 'xpu':
37+
assert result.returncode == -6, (f"Expected SIGABRT but got exit code {result.returncode}. "
38+
f"stdout: {result.stdout}, stderr: {result.stderr}")
39+
else:
40+
assert result.returncode == 1, (f"Expected runtime error but got unexpected exit code {result.returncode}. "
41+
f"stdout: {result.stdout}, stderr: {result.stderr}")
42+
else:
43+
assert result.returncode == 0, (f"Expected success but got unexpected exit code {result.returncode}. "
44+
f"stdout: {result.stdout}, stderr: {result.stderr}")
3645

3746

3847
def test_device_assert_barrier(monkeypatch, device):
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
Helper module containing Triton kernels for test_debug.py.
3+
These kernels are separated so they can be called from subprocesses.
4+
"""
5+
import torch
6+
import triton
7+
import triton.language as tl
8+
import sys
9+
10+
11+
def run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device):
12+
13+
@triton.jit(debug=jit_flag)
14+
def _kernel(COND: tl.constexpr, MASK: tl.constexpr):
15+
tl.device_assert(COND, 'test', mask=MASK)
16+
17+
kwargs = {}
18+
if opt_flag is not None:
19+
kwargs["debug"] = opt_flag
20+
21+
try:
22+
_kernel[(1, )](cond, mask, **kwargs)
23+
getattr(torch, device).synchronize()
24+
return 0
25+
except RuntimeError:
26+
return 1
27+
except Exception as e:
28+
print(f"Unexpected error: {type(e).__name__}: {e}")
29+
return 2
30+
31+
32+
if __name__ == "__main__":
33+
34+
def parse_bool_or_none(arg_str):
35+
if arg_str == "None":
36+
return None
37+
return arg_str == "True"
38+
39+
test_type = sys.argv[1]
40+
if test_type == "device_assert":
41+
cond = sys.argv[2] == "True"
42+
mask = parse_bool_or_none(sys.argv[3])
43+
opt_flag = parse_bool_or_none(sys.argv[4])
44+
jit_flag = sys.argv[5] == "True"
45+
device = sys.argv[6]
46+
triton.knobs.refresh_knobs()
47+
exit_code = run_device_assert_kernel(cond, mask, opt_flag, jit_flag, device)
48+
sys.exit(exit_code)
49+
50+
else:
51+
print(f"Unknown test type: {test_type}")
52+
sys.exit(3)

scripts/skiplist/a770/debug.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
2-
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
32
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
43
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
54
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]

scripts/skiplist/arl-h/debug.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
2-
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
32
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
43
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
54
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]

scripts/skiplist/arl-s/debug.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
2-
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
32
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
43
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
54
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]

scripts/skiplist/default/debug.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
2-
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
32
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
43
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
54
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]

scripts/skiplist/lts/debug.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
2-
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
32
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
43
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
54
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]

scripts/skiplist/mtl/debug.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
2-
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
32
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
43
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
54
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]

scripts/skiplist/xe2/debug.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2755
2-
python/test/unit/test_debug.py::test_device_assert[r"^(False|True)-(False|True)-True-(True|None)-False$|^(False|True)-True-None-(True|None)-False$|^(False|True)-True-False-(True|None)-False$|^True-False-None-(True|None)-False$"]@regexp
32
python/test/unit/test_debug.py::test_sanitize_int_add_overflow[r".*True-True$"]@regexp
43
python/test/unit/test_debug.py::test_sanitize_int_mul_overflow[r".*True-True$"]@regexp
54
python/test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True]

0 commit comments

Comments
 (0)