-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy pathtest_debug_dump.py
More file actions
51 lines (40 loc) · 1.4 KB
/
test_debug_dump.py
File metadata and controls
51 lines (40 loc) · 1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
from contextlib import contextmanager
import torch
import triton
import triton.language as tl
@contextmanager
def enable_dump_context(pass_name="1"):
    """Temporarily set the ``MLIR_ENABLE_DUMP`` environment variable.

    While the ``with`` body runs, ``MLIR_ENABLE_DUMP`` holds ``pass_name``.
    On exit -- including when the body raises -- the variable is put back
    exactly as it was found: its earlier value is restored, or the variable
    is removed again if it was not set beforehand.

    Args:
        pass_name: String exported as ``MLIR_ENABLE_DUMP`` for the duration
            of the block. Defaults to ``"1"``.

    Yields:
        None.
    """
    env_key = "MLIR_ENABLE_DUMP"
    # Snapshot the caller's setting first so the override cannot leak out of
    # the block or overwrite a value that was exported before we ran.
    previous = os.environ.get(env_key)
    os.environ[env_key] = pass_name
    try:
        yield
    finally:
        if previous is None:
            # The variable did not exist before; leave no trace of it.
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
def test_fn_dump(capfd, device, fresh_triton_cache):
    """Check which kernel launches produce MLIR dump text on stderr.

    The test is currently disabled: it returns before doing any work (see
    the TODO marker), so everything after the early return is unreachable
    and is kept only for when the test is switched back on.

    The disabled body launches a kernel that adds one to a zero tensor in
    place, three times, each under ``enable_dump_context``: once with the
    default value, once with ``"_kernel"`` and once with ``"_kernel2"``.
    After each launch it reads stderr through ``capfd`` and asserts on the
    presence (first two launches) or absence (last launch) of the
    ``"IR Dump Before"`` marker.

    Args:
        capfd: Pytest capture fixture; its ``readouterr().err`` is inspected.
        device: Device on which the input tensor is allocated.
        fresh_triton_cache: Not referenced in the body; presumably requested
            for its fixture side effect -- TODO confirm against conftest.
    """
    return  # TODO: flagtree

    N = 1024
    src = torch.zeros(N, device=device)

    def grid(META):
        # One program per BLOCK_SIZE-sized slice of the N elements.
        return (triton.cdiv(N, META["BLOCK_SIZE"]), )

    @triton.jit
    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        bumped = tl.load(src + offs, mask=offs < N) + 1
        tl.store(src + offs, bumped, mask=offs < N)

    def launch_and_read_stderr(block_size):
        # Each launch uses a different block size, then drains captured stderr.
        _kernel[grid](src, N, block_size)
        return capfd.readouterr().err

    # NOTE(review): the source's indentation was lost, so whether stderr is
    # read inside or after each ``with`` block could not be determined; it
    # is read inside here -- confirm against the intended layout.
    with enable_dump_context():
        stderr_text = launch_and_read_stderr(16)
        print(stderr_text)
        assert "IR Dump Before" in stderr_text
        assert "tt.func public @_kernel" in stderr_text

    with enable_dump_context("_kernel"):
        stderr_text = launch_and_read_stderr(32)
        assert "IR Dump Before" in stderr_text
        assert "tt.func public @_kernel" in stderr_text

    with enable_dump_context("_kernel2"):
        stderr_text = launch_and_read_stderr(64)
        assert "IR Dump Before" not in stderr_text