Commit f905a74

Fix warning with torch.amp.custom_fwd (#163)
Fixes #134. Backward compatible with older torch versions.
1 parent c2f8561

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

lightglue/lightglue.py

@@ -21,7 +21,14 @@
 torch.backends.cudnn.deterministic = True


-@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+AMP_CUSTOM_FWD_F32 = (
+    torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+    if hasattr(torch, "amp")
+    else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+)
+
+
+@AMP_CUSTOM_FWD_F32
 def normalize_keypoints(
     kpts: torch.Tensor, size: Optional[torch.Tensor] = None
 ) -> torch.Tensor:
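The pattern in this diff can be sketched as a stand-alone snippet: select the AMP decorator once at import time (the newer torch.amp.custom_fwd requires a device_type argument; older releases only have the now-deprecated torch.cuda.amp.custom_fwd), then apply it to any function whose inputs should be cast to float32 under CUDA autocast. This is a hedged sketch, not the committed code verbatim: the extra hasattr(torch.amp, "custom_fwd") guard and the double_values function are illustrative additions.

```python
import torch

# Pick the custom_fwd decorator in a version-compatible way.
# torch.amp.custom_fwd exists on newer torch and takes device_type;
# older torch only provides torch.cuda.amp.custom_fwd (which emits a
# deprecation warning on recent releases).
AMP_CUSTOM_FWD_F32 = (
    torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
    if hasattr(torch, "amp") and hasattr(torch.amp, "custom_fwd")
    else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
)


@AMP_CUSTOM_FWD_F32
def double_values(x: torch.Tensor) -> torch.Tensor:
    # Under CUDA autocast, inputs are cast to float32 before this runs;
    # outside autocast the function executes unchanged.
    return 2 * x


result = double_values(torch.ones(3))
```

Because the decorator is a no-op outside an autocast region, the decorated function still runs correctly on CPU-only installs, which is why the shim is safe to evaluate unconditionally at import time.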
