Skip to content

Commit 10dc329

Browse files
Make the remaining cache tests device agnostic (#9528)
Currently these tests are written in a way that isn't device agnostic. To fix:

* update the key into `device_caches` to be `getattr(torch, device).current_device()`, like the rest of the file.
* add the pytest fixture `device` to `test_module_load_unload`.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `it updates existing tests`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 15f9161 commit 10dc329

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

python/test/unit/runtime/test_cache.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,9 @@ def kernel(Y, fn: tl.constexpr, fn_args):
693693
kernel[(1, )](y[2], func3, (3, ))
694694
kernel[(1, )](y[3], func4, (3, 4))
695695
kernel[(1, )](y[4], func1, tuple())
696-
assert len(kernel.device_caches[0][0]) == 4
696+
697+
device = getattr(torch, device).current_device()
698+
assert len(kernel.device_caches[device][0]) == 4
697699
assert y.tolist() == [1, 2, 3, 7, 1]
698700

699701

@@ -747,19 +749,21 @@ def kernel(Y, a: tl.constexpr):
747749
kernel.warmup(b, 0, grid=(1, ))
748750
kernel.warmup(b, 1, grid=(1, ))
749751

752+
device = getattr(torch, device).current_device()
753+
750754
# Nothing has actually compiled yet
751-
assert len(kernel.device_caches[0][0]) == 4
755+
assert len(kernel.device_caches[device][0]) == 4
752756
assert len(pool.work_queue) == 4
753757

754758
# Duplicates are only submitted once
755759
kernel.warmup(a, 0, grid=(1, ))
756760
kernel.warmup(a, 1, grid=(1, ))
757-
assert len(kernel.device_caches[0][0]) == 4
761+
assert len(kernel.device_caches[device][0]) == 4
758762
assert len(pool.work_queue) == 4
759763

760764
pool.run_one()
761765
kernel[(1, )](a, 0)
762-
assert len(kernel.device_caches[0][0]) == 4
766+
assert len(kernel.device_caches[device][0]) == 4
763767
assert a[0, 0] == 0.0
764768

765769
pool.run_all()
@@ -782,7 +786,8 @@ def kernel(Y, a: tl.constexpr):
782786
kernel.warmup(b, 0, grid=(1, ))
783787
kernel.warmup(b, 1, grid=(1, ))
784788

785-
assert len(kernel.device_caches[0][0]) == 4
789+
device = getattr(torch, device).current_device()
790+
assert len(kernel.device_caches[device][0]) == 4
786791

787792
kernel[(1, )](b, 1)
788793
assert b[0, 0] == 1
@@ -894,7 +899,7 @@ def inc_counter(*args, **kwargs):
894899
assert output.item() == 31
895900

896901

897-
def test_module_load_unload(fresh_knobs):
902+
def test_module_load_unload(device, fresh_knobs):
898903

899904
@triton.jit
900905
def kernel(out_ptr, val) -> None:
@@ -912,7 +917,7 @@ def module_unload(*args, **kwargs):
912917
gc.disable()
913918
triton.knobs.runtime.module_unload_hook.add(module_unload)
914919

915-
out = torch.randn(1, dtype=torch.float32, device='cuda')
920+
out = torch.randn(1, dtype=torch.float32, device=device)
916921
pre_compile = kernel.warmup(out, 1, grid=(1, ))
917922
pre_compile._init_handles()
918923

0 commit comments

Comments (0)