1111
1212import pytest
1313import torch
14+ import torch .nn as nn
15+ import pypose as pp
1416
1517_REPO_ROOT = Path (__file__ ).resolve ().parents [2 ]
1618if str (_REPO_ROOT ) not in sys .path :
1719 sys .path .insert (0 , str (_REPO_ROOT ))
1820
19- from ba_helpers import Reproj # noqa: E402
21+ from ba_helpers import Reproj , project # noqa: E402
22+ from bae .autograd .function import TrackingTensor , map_transform
2023import bae .autograd .graph as autograd_graph # noqa: E402
2124from datapipes .bal_io import read_bal_data # noqa: E402
2225
@@ -145,6 +148,17 @@ def _jtj_diag_from_bsr(J: torch.Tensor) -> torch.Tensor:
145148 return diag_blocks .flatten ()
146149
147150
151+ def _assert_coo_no_empty_columns (J : torch .Tensor ) -> None :
152+ assert J .layout == torch .sparse_coo
153+ J = J .coalesce ()
154+ n_cols = int (J .shape [1 ])
155+ if n_cols == 0 :
156+ return
157+ cols = J .indices ()[1 ].to (torch .int64 )
158+ counts = torch .bincount (cols , minlength = n_cols )
159+ assert (counts > 0 ).all ()
160+
161+
148162def _assert_bal_correctness_criteria (
149163 J_cam : torch .Tensor ,
150164 J_pts : torch .Tensor ,
@@ -233,6 +247,7 @@ def test_bal_jacobian_structure_no_empty_columns(
233247
234248 model = Reproj (camera_params .clone (), points_3d .clone ()).to (device )
235249 residual = model (points_2d , camera_idx , point_idx )
250+ n_obs = int (points_2d .shape [0 ])
236251
237252 J_cam , J_pts = autograd_graph .jacobian (residual , [model .pose , model .points_3d ])
238253 assert J_cam .layout == torch .sparse_bsr
@@ -241,6 +256,9 @@ def test_bal_jacobian_structure_no_empty_columns(
241256 n_cams = model .pose .shape [0 ]
242257 n_pts = model .points_3d .shape [0 ]
243258
259+ assert J_cam .shape == (n_obs * 2 , n_cams * 9 )
260+ assert J_pts .shape == (n_obs * 2 , n_pts * 3 )
261+
244262 _assert_bal_correctness_criteria (
245263 J_cam ,
246264 J_pts ,
@@ -250,6 +268,92 @@ def test_bal_jacobian_structure_no_empty_columns(
250268 n_pts = n_pts ,
251269 )
252270
271+ J_full = torch .cat ([t .to_sparse_coo () for t in (J_cam , J_pts )], dim = - 1 )
272+ _assert_coo_no_empty_columns (J_full )
273+
274+
275+
276+ @map_transform
277+ def transform_points (points , se3_params ):
278+ return pp .SE3 (se3_params ).Act (points )
279+
280+
281+ class ReprojCat (nn .Module ):
282+ def __init__ (self , camera_params , points_b , points_c , se3_c ):
283+ super ().__init__ ()
284+ self .pose = nn .Parameter (TrackingTensor (camera_params ))
285+ self .points_b = nn .Parameter (TrackingTensor (points_b ))
286+ self .points_c = nn .Parameter (TrackingTensor (points_c ))
287+ self .se3_c = nn .Parameter (TrackingTensor (se3_c ))
288+ self .pose .trim_SE3_grad = True
289+ self .se3_c .trim_SE3_grad = True
290+
291+ def forward (self , points_2d , camera_indices , point_indices ):
292+ points_c = transform_points (self .points_c , self .se3_c )
293+ points_all = torch .cat ([self .points_b , points_c ], dim = 0 )
294+ points_proj = project (points_all [point_indices ], self .pose [camera_indices ])
295+ return points_proj - points_2d
296+
297+
298+ @pytest .mark .parametrize (
299+ ("dataset" , "problem_name" ),
300+ _BAL_SAMPLES ,
301+ ids = [f"{ ds } .{ name } " for ds , name in _BAL_SAMPLES ],
302+ )
303+ def test_bal_jacobian_cat_split_points_no_empty_columns (
304+ dataset : str ,
305+ problem_name : str ,
306+ bal_cache_dir : Path ,
307+ ):
308+ data = _load_bal_problem (dataset , problem_name , bal_cache_dir )
309+
310+ device = torch .device ("cpu" )
311+ dtype = torch .float64
312+
313+ camera_params = data ["camera_params" ].to (device = device , dtype = dtype )
314+ points_3d = data ["points_3d" ].to (device = device , dtype = dtype )
315+ points_2d = data ["points_2d" ].to (device = device , dtype = dtype )
316+ camera_idx = data ["camera_index_of_observations" ].to (torch .int32 ).to (device = device )
317+ point_idx = data ["point_index_of_observations" ].to (torch .int32 ).to (device = device )
318+
319+ n_pts = int (points_3d .shape [0 ])
320+ split = max (1 , n_pts // 2 )
321+ if split >= n_pts :
322+ pytest .skip ("BAL sample has <2 points; cannot construct cat split case." )
323+
324+ points_b = points_3d [:split ].clone ()
325+ points_c = points_3d [split :].clone ()
326+
327+ torch .manual_seed (0 )
328+ se3_c = pp .randn_SE3 (points_c .shape [0 ], device = device , dtype = dtype ).tensor ()
329+
330+ model = ReprojCat (camera_params .clone (), points_b , points_c , se3_c ).to (device )
331+ residual = model (points_2d , camera_idx , point_idx )
332+ n_obs = int (points_2d .shape [0 ])
333+
334+ J_cam , J_b , J_c , J_se3 = autograd_graph .jacobian (
335+ residual ,
336+ [model .pose , model .points_b , model .points_c , model .se3_c ],
337+ )
338+
339+ n_cams = model .pose .shape [0 ]
340+ n_b = model .points_b .shape [0 ]
341+ n_c = model .points_c .shape [0 ]
342+
343+ assert J_cam .shape == (n_obs * 2 , n_cams * 9 )
344+ assert J_b .shape == (n_obs * 2 , n_b * 3 )
345+ assert J_c .shape == (n_obs * 2 , n_c * 3 )
346+ assert J_se3 .shape == (n_obs * 2 , n_c * 6 )
347+
348+ J_full = torch .cat (
349+ [t .to_sparse_coo () for t in (J_cam , J_b , J_c , J_se3 )],
350+ dim = - 1 ,
351+ ).coalesce ()
352+ _assert_coo_no_empty_columns (J_full )
353+ diag = torch .zeros (J_full .shape [1 ], dtype = J_full .dtype , device = J_full .device )
354+ diag .scatter_add_ (0 , J_full .indices ()[1 ].to (torch .int64 ), J_full .values ().square ())
355+ assert (diag > 0 ).all ()
356+
253357
254358@pytest .mark .parametrize (
255359 ("dataset" , "problem_name" ),
0 commit comments