Skip to content

Commit 7498c43

Browse files
authored
Use torch api in megadepth rotation augmentation (#100)
1 parent c06b472 commit 7498c43

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

gluefactory/datasets/megadepth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,9 @@ def _read_view(self, scene, idx):
283283
k = 0
284284
if np.random.rand() < p:
285285
k = np.random.choice(2, 1, replace=False)[0] * 2 - 1
286-
img = np.rot90(img, k=-k, axes=(-2, -1))
286+
img = torch.rot90(img, k=-k, dims=[1, 2])
287287
if self.conf.read_depth:
288-
depth = np.rot90(depth, k=-k, axes=(-2, -1)).copy()
288+
depth = torch.rot90(depth, k=-k, dims=[1, 2]).clone()
289289
K = rotate_intrinsics(K, img.shape, k + 2)
290290
T = rotate_pose_inplane(T, k + 2)
291291

@@ -311,8 +311,8 @@ def _read_view(self, scene, idx):
311311
features = self.feature_loader({k: [v] for k, v in data.items()})
312312
if do_rotate and k != 0:
313313
# ang = np.deg2rad(k * 90.)
314-
kpts = features["keypoints"].copy()
315-
x, y = kpts[:, 0].copy(), kpts[:, 1].copy()
314+
kpts = features["keypoints"].clone()
315+
x, y = kpts[:, 0].clone(), kpts[:, 1].clone()
316316
w, h = data["image_size"]
317317
if k == 1:
318318
kpts[:, 0] = w - y

0 commit comments

Comments
 (0)