We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c2f8561 commit f905a74Copy full SHA for f905a74
1 file changed
lightglue/lightglue.py
@@ -21,7 +21,14 @@
21
torch.backends.cudnn.deterministic = True
22
23
24
-@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+AMP_CUSTOM_FWD_F32 = (
25
+ torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
26
+ if hasattr(torch, "amp")
27
+ else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
28
+)
29
+
30
31
+@AMP_CUSTOM_FWD_F32
32
def normalize_keypoints(
33
kpts: torch.Tensor, size: Optional[torch.Tensor] = None
34
) -> torch.Tensor:
0 commit comments