diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 472d2a36dc32..7f158a5b72c9 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -165,14 +165,27 @@ def load_dialects(self, ctx): if HIPBackend.instrumentation: HIPBackend.instrumentation.load_dialects(ctx) + # is_within_2gb() needs to check for a torch subobject and this var tracks torch + # availability state: None - not tested, True - torch is present. Anything else - + # no torch available. First call to is_within_2gb() checks torch availability + # and caches it. + _torch_available: None | bool = None + @staticmethod def is_within_2gb(arg): - import torch + if HIPBackend._torch_available is None: + try: + import torch + HIPBackend._torch_available = True + except ImportError: + HIPBackend._torch_available = False + elif HIPBackend._torch_available: + import torch MAX_INT_32 = 2**31 - 1 if hasattr(arg, "ptr_range"): return arg.ptr_range() <= MAX_INT_32 - if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"): + if HIPBackend._torch_available and isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"): return arg.untyped_storage().size() <= MAX_INT_32 return False