Skip to content

Commit 269deb3

Browse files
authored
Update to use the new attribute setting for tf32. (#835)
1 parent 94d73aa commit 269deb3

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

helion/_testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,8 @@ def run_example(
458458
atol: Absolute tolerance for correctness check (default: 1e-1)
459459
bwd: Whether to also test backward pass (default: False)
460460
"""
461-
torch.backends.cuda.matmul.fp32_precision = "tf32"
462-
torch.backends.cudnn.conv.fp32_precision = "tf32" # type: ignore[reportAttributeAccessIssue]
461+
torch.backends.cuda.matmul.allow_tf32 = True
462+
torch.backends.cudnn.allow_tf32 = True
463463

464464
# Normalize to dict format
465465
kernels = kernel_fn if isinstance(kernel_fn, dict) else {kernel_name: kernel_fn}

helion/autotuner/base_cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def torch_key_wrapper() -> str:
6565

6666
@functools.cache
6767
def triton_key_wrapper() -> str:
68-
from torch._inductor.runtime.triton_compat import triton_key
68+
from torch._inductor.runtime.triton_compat import (
69+
triton_key, # pyright: ignore[reportAttributeAccessIssue]
70+
)
6971

7072
return triton_key()
7173

helion/autotuner/local_cache.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def _generate_key(self) -> LooseAutotuneCacheKey:
5555
for arg in self.args:
5656
if isinstance(arg, torch.Tensor):
5757
nms = torch.xpu if torch.xpu.is_available() else torch.cuda
58-
device_properties = nms.get_device_properties(arg.device)
58+
device_properties = nms.get_device_properties(arg.device) # pyright: ignore[reportAttributeAccessIssue]
5959
if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue]
6060
hardware = device_properties.name
61-
runtime_name = str(torch.version.cuda)
61+
runtime_name = str(torch.version.cuda) # pyright: ignore[reportAttributeAccessIssue]
6262
elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue]
63-
hardware = device_properties.gcnArchName
63+
hardware = device_properties.gcnArchName # pyright: ignore[reportAttributeAccessIssue]
6464
runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue]
6565
else:
6666
hardware = device_properties.name

0 commit comments

Comments
 (0)