11import kornia
22import torch
3- from kornia .geometry .calibration .pnp import _mean_isotropic_scale_normalize
43from torch import Tensor
54
5+ from ..utils import misc
66from . import transforms as gtr
77from .reconstruction import Pose
88
@@ -30,26 +30,29 @@ def _mean_isotropic_scale_normalize(
3030 scale = scale [:, None ] # B x 1
3131
3232 norm_t_w = (
33- torch .cat ([kornia .utils .eye_like (D_int , points ), - x_mean ], axis = - 2 ) * scale
33+ torch .cat (
34+ [kornia .utils .eye_like (D_int , points ), - x_mean .transpose (- 1 , - 2 )], dim = - 1
35+ )
36+ * scale [..., None ]
3437 )
3538
36- last_col = torch .cat (
37- [torch .zeros_like (x_mean ), torch .ones_like (x_mean [..., :1 ])], axis = - 1
38- ).transpose (
39- - 1 , - 2
40- ) # Bx(D+1)x1
39+ last_row = torch .cat (
40+ [torch .zeros_like (x_mean ), torch .ones_like (x_mean [..., :1 ])], dim = - 1
41+ )
4142 norm_t_w = torch .cat (
42- [norm_t_w , last_col ],
43- axis = - 1 ,
43+ [norm_t_w , last_row ],
44+ dim = - 2 ,
4445 ) # Bx(D+1)x(D+1)
4546
4647 points_norm = kornia .geometry .linalg .transform_points (norm_t_w , points ) # BxNxD
4748 return (points_norm , norm_t_w )
4849
4950
50- def pnp_dlt (p3d_w : Tensor , p2d_c : Tensor ) -> Pose :
51+ @misc .AMP_CUSTOM_FWD_F32
52+ def pnp_dlt (p3d_w : Tensor , p2d_c : Tensor , weights : Tensor | None = None ) -> Pose :
5153 # p3d_w: (B, N, 3) - 3D points in world coordinates
5254 # p2d_c: (B, N, 2) - 2D points in camera coordinates (normalized)
55+ # weights: (B, N) - weights for each point correspondence
5356 B , N = p3d_w .shape [:2 ]
5457
5558 p3d_w_norm , world_transform_norm = _mean_isotropic_scale_normalize (p3d_w )
@@ -72,6 +75,10 @@ def pnp_dlt(p3d_w: Tensor, p2d_c: Tensor) -> Pose:
7275 dim = - 1 ,
7376 )
7477
78+ if weights is not None :
79+ weights = weights .repeat_interleave (2 , dim = 1 ).sqrt () # (B, 2N)
80+ system = system * weights [..., None ] # Apply weights to the system
81+
7582 # Getting the solution vectors.
7683 _ , _ , v = torch .svd (system )
7784 solution = v [..., - 1 ]
@@ -87,7 +94,6 @@ def pnp_dlt(p3d_w: Tensor, p2d_c: Tensor) -> Pose:
8794 )
8895 # Creating solution_4x4
8996 solution_4x4 = torch .cat ([solution , last_row ], dim = - 2 )
90- print (solution_4x4 )
9197
9298 # De-normalizing the solution
9399 intermediate = torch .bmm (solution_4x4 , world_transform_norm )
0 commit comments