Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,27 @@ def load_dialects(self, ctx):
if HIPBackend.instrumentation:
HIPBackend.instrumentation.load_dialects(ctx)

# is_within_2gb() needs to check for a torch subobject, and this class attribute
# caches torch availability: None - not probed yet; True - torch is importable;
# False - torch is not importable. The first call to is_within_2gb() performs
# the probe and stores the result here so torch remains an optional dependency.
_torch_available: None | bool = None

@staticmethod
def is_within_2gb(arg):
    """Return True if *arg* is known to fit within a 32-bit (2 GB) range.

    Checks `arg.ptr_range()` when available; otherwise, for torch tensors,
    checks the untyped storage size. torch is probed lazily on the first
    call and the result is cached in HIPBackend._torch_available, so torch
    stays an optional dependency.

    :param arg: kernel argument; may expose `ptr_range()` or be a
        torch.Tensor (presumably a device tensor — TODO confirm).
    :return: True when the argument provably fits in 2 GB, False otherwise
        (including when torch is unavailable and no `ptr_range` exists).
    """
    # One-time probe: only the very first call pays the ImportError cost.
    if HIPBackend._torch_available is None:
        try:
            import torch  # noqa: F401 -- availability probe only
            HIPBackend._torch_available = True
        except ImportError:
            HIPBackend._torch_available = False
    # Bind torch on every call where it is available. The original bound it
    # in the `try` on the first call and via a separate `elif` import on
    # later calls; a single guarded import removes that asymmetry.
    if HIPBackend._torch_available:
        import torch

    MAX_INT_32 = 2**31 - 1
    if hasattr(arg, "ptr_range"):
        return arg.ptr_range() <= MAX_INT_32
    # The availability flag short-circuits the `and`, so `torch` is never
    # referenced while unbound.
    if HIPBackend._torch_available and isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
        return arg.untyped_storage().size() <= MAX_INT_32
    return False

Expand Down
Loading