Skip to content

Commit eea33ff

Browse files
authored
[TEST] Update fresh_knobs fixture default behavior (#9184)
Refactor the default behavior of fresh_knobs in test fixtures. `fresh_knobs` (default): Now preserves library paths (build, nvidia, amd knobs) Most tests need CUDA toolkit paths to compile successfully Respects environment variables like `TRITON_PTXAS_BLACKWELL_PATH` `fresh_knobs_including_libraries` (new): Resets ALL knobs including library paths
1 parent ebc4992 commit eea33ff

7 files changed

Lines changed: 47 additions & 21 deletions

File tree

python/test/conftest.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,26 @@ def fresh_triton_cache():
2727

2828
@pytest.fixture
2929
def fresh_knobs():
30+
"""
31+
Resets all knobs except ``build``, ``nvidia``, and ``amd`` (preserves
32+
library paths needed to compile kernels).
33+
"""
3034
from triton._internal_testing import _fresh_knobs_impl
31-
fresh_function, reset_function = _fresh_knobs_impl()
35+
fresh_function, reset_function = _fresh_knobs_impl(skipped_attr={"build", "nvidia", "amd"})
3236
try:
3337
yield fresh_function()
3438
finally:
3539
reset_function()
3640

3741

3842
@pytest.fixture
39-
def fresh_knobs_except_libraries():
43+
def fresh_knobs_including_libraries():
4044
"""
41-
A variant of `fresh_knobs` that keeps library path
42-
information from the environment as these may be
43-
needed to successfully compile kernels.
45+
Resets ALL knobs including ``build``, ``nvidia``, and ``amd``.
46+
Use for tests that verify initial values of these knobs.
4447
"""
4548
from triton._internal_testing import _fresh_knobs_impl
46-
fresh_function, reset_function = _fresh_knobs_impl(skipped_attr={"build", "nvidia", "amd"})
49+
fresh_function, reset_function = _fresh_knobs_impl()
4750
try:
4851
yield fresh_function()
4952
finally:

python/test/unit/language/test_core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5745,7 +5745,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
57455745

57465746
@pytest.mark.parametrize("enable_fp_fusion", [False, True])
57475747
@pytest.mark.parametrize("default_override", [False, True])
5748-
def test_enable_fp_fusion(enable_fp_fusion, default_override, device, fresh_knobs_except_libraries):
5748+
def test_enable_fp_fusion(enable_fp_fusion, default_override, device, fresh_knobs):
57495749
# Sequential multiply add can be fused by backend
57505750
@triton.jit
57515751
def mul_add(data):
@@ -5754,7 +5754,7 @@ def mul_add(data):
57545754

57555755
data = torch.randn((128, ), device=device, dtype=torch.float32)
57565756
if default_override:
5757-
fresh_knobs_except_libraries.language.default_fp_fusion = enable_fp_fusion
5757+
fresh_knobs.language.default_fp_fusion = enable_fp_fusion
57585758
h = mul_add.warmup(data, grid=(1, ))
57595759
else:
57605760
h = mul_add.warmup(data, grid=(1, ), enable_fp_fusion=enable_fp_fusion)
@@ -5772,7 +5772,7 @@ def mul_add(data):
57725772

57735773
@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
57745774
@pytest.mark.parametrize("enable_reflect_ftz", [False, True])
5775-
def test_enable_reflect_ftz(enable_reflect_ftz, device, fresh_knobs_except_libraries):
5775+
def test_enable_reflect_ftz(enable_reflect_ftz, device, fresh_knobs):
57765776

57775777
@triton.jit
57785778
def exp2(data):
@@ -5793,7 +5793,7 @@ def exp2(data):
57935793

57945794
@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90", "gfx942", "gfx950", "gfx1200"])
57955795
@pytest.mark.parametrize("env_var_override", [False, True])
5796-
def test_override_arch(arch, env_var_override, device, fresh_knobs_except_libraries):
5796+
def test_override_arch(arch, env_var_override, device, fresh_knobs):
57975797
if arch.startswith("sm") and not is_cuda():
57985798
pytest.skip(f"{arch} arch only for CUDA")
57995799
elif arch.startswith("gfx") and not is_hip():
@@ -5810,7 +5810,7 @@ def simple(data, out):
58105810

58115811
if is_cuda():
58125812
if env_var_override:
5813-
fresh_knobs_except_libraries.runtime.override_arch = str(arch)
5813+
fresh_knobs.runtime.override_arch = str(arch)
58145814
h = simple.warmup(data, out, grid=(1, ))
58155815
else:
58165816
h = simple.warmup(data, out, arch=arch, grid=(1, ))
@@ -5820,7 +5820,7 @@ def simple(data, out):
58205820
# For HIP, the generated kernel is a binary containing the final ISA. So we cannot run
58215821
# them like CUDA side if the chip doesn't match. Here we just check generated ISA.
58225822
if env_var_override:
5823-
fresh_knobs_except_libraries.runtime.override_arch = str(arch)
5823+
fresh_knobs.runtime.override_arch = str(arch)
58245824
h = simple.warmup(data, out, grid=(1, ))
58255825
else:
58265826
h = simple.warmup(data, out, arch=arch, grid=(1, ))

python/test/unit/language/test_tuple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,15 @@ def test_passing_nested_tuple_with_constexpr(device):
269269
_nested_tuple_kernel[(1, )](((1, ), (tl.constexpr(2), )))
270270

271271

272-
def test_passing_nested_tuple_with_constexpr_and_jit_hook(device, fresh_knobs_except_libraries):
272+
def test_passing_nested_tuple_with_constexpr_and_jit_hook(device, fresh_knobs):
273273
# get the serialized specialization data
274274
specialization_data = None
275275

276276
def cache_hook(*args, **kwargs):
277277
nonlocal specialization_data
278278
specialization_data = kwargs["compile"]["specialization_data"]
279279

280-
fresh_knobs_except_libraries.runtime.jit_cache_hook = cache_hook
280+
fresh_knobs.runtime.jit_cache_hook = cache_hook
281281

282282
device = getattr(torch, device).current_device()
283283

python/test/unit/runtime/test_build.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_compile_module(fresh_triton_cache):
6565
assert mod2.__file__ == mod.__file__
6666

6767

68-
def test_compile_module_bad_cache(fresh_knobs_except_libraries):
68+
def test_compile_module_bad_cache(fresh_knobs):
6969
with tempfile.TemporaryDirectory() as tmpd:
7070
tmp = Path(tmpd)
7171
called_get_file = False
@@ -79,7 +79,7 @@ def get_file(self, filename: str) -> str | None:
7979
return str(tmp / filename)
8080

8181
# First corrupt the cache
82-
fresh_knobs_except_libraries.cache.manager_class = InvalidFileCacheManager
82+
fresh_knobs.cache.manager_class = InvalidFileCacheManager
8383

8484
mod = compile_module_from_src(TEST_MODULE_C, "test_module")
8585
assert called_get_file

python/test/unit/runtime/test_compilation_listener.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def cumsum_kernel(ptr):
1717
tl.store(block, tl.cumsum(x, 0))
1818

1919

20-
def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache: str) -> None:
20+
def test_compile_stats(device: str, fresh_knobs: Any, fresh_triton_cache: str) -> None:
2121
captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], dict[str, Any], CompileTimes, bool], None] = None
2222

2323
def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str], metadata_group: dict[str, Any],
@@ -26,7 +26,7 @@ def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str],
2626
assert captured is None
2727
captured = (src, metadata, metadata_group, times, cache_hit)
2828

29-
fresh_knobs_except_libraries.compilation.listener = compile_listener
29+
fresh_knobs.compilation.listener = compile_listener
3030

3131
x = torch.randn(4, device=device)
3232
cumsum_kernel[(1, )](x)

python/test/unit/test_knobs.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def test_env_updated(fresh_knobs, monkeypatch):
113113

114114
@pytest.mark.parametrize("truthy, falsey", [("1", "0"), ("true", "false"), ("True", "False"), ("TRUE", "FALSE"),
115115
("y", "n"), ("YES", "NO"), ("ON", "OFF")])
116-
def test_read_env(truthy, falsey, fresh_knobs, monkeypatch):
116+
def test_read_env(truthy, falsey, fresh_knobs_including_libraries, monkeypatch):
117+
fresh_knobs = fresh_knobs_including_libraries
117118
# bool defaulting to False
118119
assert not fresh_knobs.runtime.debug
119120
# bool defaulting to True
@@ -170,7 +171,8 @@ def test_triton_home(fresh_knobs, monkeypatch):
170171
assert fresh_knobs.cache.override_dir == "/tmp/user/triton_home/.triton/override"
171172

172173

173-
def test_set_knob_directly(fresh_knobs, monkeypatch):
174+
def test_set_knob_directly(fresh_knobs_including_libraries, monkeypatch):
175+
fresh_knobs = fresh_knobs_including_libraries
174176
assert fresh_knobs.cache.dir.endswith(".triton/cache")
175177

176178
fresh_knobs.cache.dir = "/tmp/triton_cache"
@@ -279,7 +281,8 @@ def test_nvidia_tool(fresh_knobs, tmp_path, monkeypatch):
279281
assert fresh_knobs.nvidia.ptxas_options is None
280282

281283

282-
def test_opt_bool(fresh_knobs, monkeypatch):
284+
def test_opt_bool(fresh_knobs_including_libraries, monkeypatch):
285+
fresh_knobs = fresh_knobs_including_libraries
283286
assert fresh_knobs.amd.use_block_pingpong is None
284287
monkeypatch.setenv("TRITON_HIP_USE_BLOCK_PINGPONG", "0")
285288
assert not fresh_knobs.amd.use_block_pingpong

python/triton_kernels/tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,26 @@ def device(request):
1414

1515
@pytest.fixture
1616
def fresh_knobs():
17+
"""
18+
Default fresh knobs fixture that preserves library path
19+
information from the environment as these are typically
20+
needed to successfully compile kernels.
21+
"""
22+
from triton._internal_testing import _fresh_knobs_impl
23+
fresh_function, reset_function = _fresh_knobs_impl(skipped_attr={"build", "nvidia", "amd"})
24+
try:
25+
yield fresh_function()
26+
finally:
27+
reset_function()
28+
29+
30+
@pytest.fixture
31+
def fresh_knobs_including_libraries():
32+
"""
33+
A variant of `fresh_knobs` that resets ALL knobs including
34+
library paths. Use this only for tests that need complete
35+
environment isolation.
36+
"""
1737
from triton._internal_testing import _fresh_knobs_impl
1838
fresh_function, reset_function = _fresh_knobs_impl()
1939
try:

0 commit comments

Comments
 (0)