Skip to content

Commit f3b9504

Browse files
committed
Fixees in absolute pose dlt
1 parent 2d60ff2 commit f3b9504

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

gluefactory/geometry/absolute_pose.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import kornia
22
import torch
3-
from kornia.geometry.calibration.pnp import _mean_isotropic_scale_normalize
43
from torch import Tensor
54

5+
from ..utils import misc
66
from . import transforms as gtr
77
from .reconstruction import Pose
88

@@ -30,26 +30,29 @@ def _mean_isotropic_scale_normalize(
3030
scale = scale[:, None] # B x 1
3131

3232
norm_t_w = (
33-
torch.cat([kornia.utils.eye_like(D_int, points), -x_mean], axis=-2) * scale
33+
torch.cat(
34+
[kornia.utils.eye_like(D_int, points), -x_mean.transpose(-1, -2)], dim=-1
35+
)
36+
* scale[..., None]
3437
)
3538

36-
last_col = torch.cat(
37-
[torch.zeros_like(x_mean), torch.ones_like(x_mean[..., :1])], axis=-1
38-
).transpose(
39-
-1, -2
40-
) # Bx(D+1)x1
39+
last_row = torch.cat(
40+
[torch.zeros_like(x_mean), torch.ones_like(x_mean[..., :1])], dim=-1
41+
)
4142
norm_t_w = torch.cat(
42-
[norm_t_w, last_col],
43-
axis=-1,
43+
[norm_t_w, last_row],
44+
dim=-2,
4445
) # Bx(D+1)x(D+1)
4546

4647
points_norm = kornia.geometry.linalg.transform_points(norm_t_w, points) # BxNxD
4748
return (points_norm, norm_t_w)
4849

4950

50-
def pnp_dlt(p3d_w: Tensor, p2d_c: Tensor) -> Pose:
51+
@misc.AMP_CUSTOM_FWD_F32
52+
def pnp_dlt(p3d_w: Tensor, p2d_c: Tensor, weights: Tensor | None = None) -> Pose:
5153
# p3d_w: (B, N, 3) - 3D points in world coordinates
5254
# p2d_c: (B, N, 2) - 2D points in camera coordinates (normalized)
55+
# weights: (B, N) - weights for each point correspondence
5356
B, N = p3d_w.shape[:2]
5457

5558
p3d_w_norm, world_transform_norm = _mean_isotropic_scale_normalize(p3d_w)
@@ -72,6 +75,10 @@ def pnp_dlt(p3d_w: Tensor, p2d_c: Tensor) -> Pose:
7275
dim=-1,
7376
)
7477

78+
if weights is not None:
79+
weights = weights.repeat_interleave(2, dim=1).sqrt() # (B, 2N)
80+
system = system * weights[..., None] # Apply weights to the system
81+
7582
# Getting the solution vectors.
7683
_, _, v = torch.svd(system)
7784
solution = v[..., -1]
@@ -87,7 +94,6 @@ def pnp_dlt(p3d_w: Tensor, p2d_c: Tensor) -> Pose:
8794
)
8895
# Creating solution_4x4
8996
solution_4x4 = torch.cat([solution, last_row], dim=-2)
90-
print(solution_4x4)
9197

9298
# De-normalizing the solution
9399
intermediate = torch.bmm(solution_4x4, world_transform_norm)

0 commit comments

Comments
 (0)