|
2 | 2 | import torch |
3 | 3 | import triton.language as tl |
4 | 4 | import triton |
| 5 | +import sys |
| 6 | +import subprocess |
| 7 | +import os |
5 | 8 |
|
6 | 9 |
|
7 | 10 | @pytest.mark.parametrize('cond', [True, False]) |
|
10 | 13 | @pytest.mark.parametrize('env_var', [True, False]) |
11 | 14 | @pytest.mark.parametrize('jit_flag', [True, False]) |
12 | 15 | @pytest.mark.forked |
13 | | -def test_device_assert(monkeypatch, cond, mask, opt_flag, env_var, jit_flag, device): |
14 | | - monkeypatch.setenv("TRITON_DEBUG", str(int(env_var))) |
15 | | - triton.knobs.refresh_knobs() |
16 | | - torch.zeros([1], dtype=torch.int32, device=device) |
17 | | - |
18 | | - @triton.jit(debug=jit_flag) |
19 | | - def _kernel(COND: tl.constexpr, MASK: tl.constexpr): |
20 | | - tl.device_assert(COND, 'test', mask=MASK) |
| 16 | +def test_device_assert(cond, mask, opt_flag, env_var, jit_flag, device): |
| 17 | + """Temporary subprocess solution due to: |
| 18 | + https://github.com/pytorch/pytorch/issues/142135""" |
21 | 19 |
|
22 | 20 | is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag) |
23 | 21 |
|
24 | | - kwargs = {} |
25 | | - if opt_flag is not None: |
26 | | - kwargs["debug"] = opt_flag |
27 | | - |
28 | | - if not cond and is_debug and mask is not False: |
29 | | - with pytest.raises(RuntimeError): |
30 | | - _kernel[(1, )](cond, mask, **kwargs) |
31 | | - getattr(torch, device).synchronize() |
32 | | - return |
33 | | - |
34 | | - _kernel[(1, )](cond, mask, **kwargs) |
35 | | - getattr(torch, device).synchronize() |
| 22 | + should_fail = not cond and is_debug and mask is not False |
| 23 | + kernel_file = os.path.join(os.path.dirname(__file__), "test_debug_kernels.py") |
| 24 | + mask_str = "None" if mask is None else str(mask) |
| 25 | + opt_flag_str = "None" if opt_flag is None else str(opt_flag) |
| 26 | + |
| 27 | + env = os.environ.copy() |
| 28 | + env["TRITON_DEBUG"] = str(int(env_var)) |
| 29 | + |
| 30 | + result = subprocess.run( |
| 31 | + [sys.executable, kernel_file, "device_assert", |
| 32 | + str(cond), mask_str, opt_flag_str, |
| 33 | + str(jit_flag), device], capture_output=True, text=True, env=env) |
| 34 | + |
| 35 | + if should_fail: |
| 36 | + if device == 'xpu': |
| 37 | + assert result.returncode == -6, (f"Expected SIGABRT but got exit code {result.returncode}. " |
| 38 | + f"stdout: {result.stdout}, stderr: {result.stderr}") |
| 39 | + else: |
| 40 | + assert result.returncode == 1, (f"Expected runtime error but got unexpected exit code {result.returncode}. " |
| 41 | + f"stdout: {result.stdout}, stderr: {result.stderr}") |
| 42 | + else: |
| 43 | + assert result.returncode == 0, (f"Expected success but got unexpected exit code {result.returncode}. " |
| 44 | + f"stdout: {result.stdout}, stderr: {result.stderr}") |
36 | 45 |
|
37 | 46 |
|
38 | 47 | def test_device_assert_barrier(monkeypatch, device): |
|
0 commit comments