Commit b587cd1

[test] add real bal example for cat
1 parent 812382c commit b587cd1

File tree: 3 files changed (+149 −2 lines)

  bae/autograd/graph.py
  tests/autograd/test_bal_jacobian.py
  tests/autograd/test_graph_jacobian.py


bae/autograd/graph.py
Lines changed: 3 additions & 1 deletion

@@ -64,11 +64,13 @@ def _slice_upstream_tuple_columns(
 
     if indices is None:
         indices = torch.arange(n_rows_blocks, device=values.device, dtype=torch.int32)
+    elif indices.device != values.device:
+        indices = indices.to(device=values.device)
 
     mask = (indices >= col_start) & (indices < col_end)
     crow = torch.zeros(n_rows_blocks + 1, device=values.device, dtype=torch.int32)
     crow[1:] = torch.cumsum(mask.to(crow.dtype), dim=0)
-    col_f = (indices[mask] - col_start).to(torch.int32)
+    col_f = (indices[mask] - col_start).to(device=values.device, dtype=torch.int32)
     val_f = values[mask]
 
     return torch.sparse_bsr_tensor(
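As a side note on this hunk: the new guard only makes sure the block-column indices live on the same device as the block values before they are masked and turned into BSR metadata. A minimal, self-contained sketch of that pattern follows; all sizes and variable names are illustrative, not code from the repository.

import torch

# Toy data: 3 block-rows of 2x3 blocks, one candidate block per row whose
# block-column index is given by `indices`; keep only columns in [col_start, col_end).
values = torch.arange(3 * 2 * 3, dtype=torch.float64).reshape(3, 2, 3)
indices = torch.tensor([0, 2, 1], dtype=torch.int32)
col_start, col_end = 1, 3
n_rows_blocks = values.shape[0]

# The guard added in the patch: align `indices` with `values` so the masking
# and fancy indexing below all happen on a single device.
if indices.device != values.device:
    indices = indices.to(device=values.device)

mask = (indices >= col_start) & (indices < col_end)            # which blocks survive
crow = torch.zeros(n_rows_blocks + 1, device=values.device, dtype=torch.int32)
crow[1:] = torch.cumsum(mask.to(crow.dtype), dim=0)            # compressed row pointers
col_f = (indices[mask] - col_start).to(device=values.device, dtype=torch.int32)
val_f = values[mask]

J = torch.sparse_bsr_tensor(
    crow, col_f, val_f,
    size=(n_rows_blocks * 2, (col_end - col_start) * 3),       # 2x3 blocks
)
print(J.to_dense().shape)  # torch.Size([6, 6])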

tests/autograd/test_bal_jacobian.py
Lines changed: 105 additions & 1 deletion

@@ -11,12 +11,15 @@
 
 import pytest
 import torch
+import torch.nn as nn
+import pypose as pp
 
 _REPO_ROOT = Path(__file__).resolve().parents[2]
 if str(_REPO_ROOT) not in sys.path:
     sys.path.insert(0, str(_REPO_ROOT))
 
-from ba_helpers import Reproj  # noqa: E402
+from ba_helpers import Reproj, project  # noqa: E402
+from bae.autograd.function import TrackingTensor, map_transform
 import bae.autograd.graph as autograd_graph  # noqa: E402
 from datapipes.bal_io import read_bal_data  # noqa: E402
 
@@ -145,6 +148,17 @@ def _jtj_diag_from_bsr(J: torch.Tensor) -> torch.Tensor:
     return diag_blocks.flatten()
 
 
+def _assert_coo_no_empty_columns(J: torch.Tensor) -> None:
+    assert J.layout == torch.sparse_coo
+    J = J.coalesce()
+    n_cols = int(J.shape[1])
+    if n_cols == 0:
+        return
+    cols = J.indices()[1].to(torch.int64)
+    counts = torch.bincount(cols, minlength=n_cols)
+    assert (counts > 0).all()
+
+
 def _assert_bal_correctness_criteria(
     J_cam: torch.Tensor,
     J_pts: torch.Tensor,
@@ -233,6 +247,7 @@ def test_bal_jacobian_structure_no_empty_columns(
 
     model = Reproj(camera_params.clone(), points_3d.clone()).to(device)
     residual = model(points_2d, camera_idx, point_idx)
+    n_obs = int(points_2d.shape[0])
 
     J_cam, J_pts = autograd_graph.jacobian(residual, [model.pose, model.points_3d])
     assert J_cam.layout == torch.sparse_bsr
@@ -241,6 +256,9 @@
     n_cams = model.pose.shape[0]
     n_pts = model.points_3d.shape[0]
 
+    assert J_cam.shape == (n_obs * 2, n_cams * 9)
+    assert J_pts.shape == (n_obs * 2, n_pts * 3)
+
     _assert_bal_correctness_criteria(
         J_cam,
         J_pts,
@@ -250,6 +268,92 @@
         n_pts=n_pts,
     )
 
+    J_full = torch.cat([t.to_sparse_coo() for t in (J_cam, J_pts)], dim=-1)
+    _assert_coo_no_empty_columns(J_full)
+
+
+@map_transform
+def transform_points(points, se3_params):
+    return pp.SE3(se3_params).Act(points)
+
+
+class ReprojCat(nn.Module):
+    def __init__(self, camera_params, points_b, points_c, se3_c):
+        super().__init__()
+        self.pose = nn.Parameter(TrackingTensor(camera_params))
+        self.points_b = nn.Parameter(TrackingTensor(points_b))
+        self.points_c = nn.Parameter(TrackingTensor(points_c))
+        self.se3_c = nn.Parameter(TrackingTensor(se3_c))
+        self.pose.trim_SE3_grad = True
+        self.se3_c.trim_SE3_grad = True
+
+    def forward(self, points_2d, camera_indices, point_indices):
+        points_c = transform_points(self.points_c, self.se3_c)
+        points_all = torch.cat([self.points_b, points_c], dim=0)
+        points_proj = project(points_all[point_indices], self.pose[camera_indices])
+        return points_proj - points_2d
+
+
+@pytest.mark.parametrize(
+    ("dataset", "problem_name"),
+    _BAL_SAMPLES,
+    ids=[f"{ds}.{name}" for ds, name in _BAL_SAMPLES],
+)
+def test_bal_jacobian_cat_split_points_no_empty_columns(
+    dataset: str,
+    problem_name: str,
+    bal_cache_dir: Path,
+):
+    data = _load_bal_problem(dataset, problem_name, bal_cache_dir)
+
+    device = torch.device("cpu")
+    dtype = torch.float64
+
+    camera_params = data["camera_params"].to(device=device, dtype=dtype)
+    points_3d = data["points_3d"].to(device=device, dtype=dtype)
+    points_2d = data["points_2d"].to(device=device, dtype=dtype)
+    camera_idx = data["camera_index_of_observations"].to(torch.int32).to(device=device)
+    point_idx = data["point_index_of_observations"].to(torch.int32).to(device=device)
+
+    n_pts = int(points_3d.shape[0])
+    split = max(1, n_pts // 2)
+    if split >= n_pts:
+        pytest.skip("BAL sample has <2 points; cannot construct cat split case.")
+
+    points_b = points_3d[:split].clone()
+    points_c = points_3d[split:].clone()
+
+    torch.manual_seed(0)
+    se3_c = pp.randn_SE3(points_c.shape[0], device=device, dtype=dtype).tensor()
+
+    model = ReprojCat(camera_params.clone(), points_b, points_c, se3_c).to(device)
+    residual = model(points_2d, camera_idx, point_idx)
+    n_obs = int(points_2d.shape[0])
+
+    J_cam, J_b, J_c, J_se3 = autograd_graph.jacobian(
+        residual,
+        [model.pose, model.points_b, model.points_c, model.se3_c],
+    )
+
+    n_cams = model.pose.shape[0]
+    n_b = model.points_b.shape[0]
+    n_c = model.points_c.shape[0]
+
+    assert J_cam.shape == (n_obs * 2, n_cams * 9)
+    assert J_b.shape == (n_obs * 2, n_b * 3)
+    assert J_c.shape == (n_obs * 2, n_c * 3)
+    assert J_se3.shape == (n_obs * 2, n_c * 6)
+
+    J_full = torch.cat(
+        [t.to_sparse_coo() for t in (J_cam, J_b, J_c, J_se3)],
+        dim=-1,
+    ).coalesce()
+    _assert_coo_no_empty_columns(J_full)
+    diag = torch.zeros(J_full.shape[1], dtype=J_full.dtype, device=J_full.device)
+    diag.scatter_add_(0, J_full.indices()[1].to(torch.int64), J_full.values().square())
+    assert (diag > 0).all()
+
 
 @pytest.mark.parametrize(
     ("dataset", "problem_name"),

tests/autograd/test_graph_jacobian.py
Lines changed: 41 additions & 0 deletions

@@ -232,3 +232,44 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
     assert JB_sparse.crow_indices()[n_a].item() == 0
     assert JB_sparse.crow_indices()[-1].item() == n_b
     assert torch.equal(JB_sparse.col_indices(), idx_b)
+
+
+class CatIndexResidual(nn.Module):
+    def __init__(self, A: torch.Tensor, B: torch.Tensor):
+        super().__init__()
+        self.A = nn.Parameter(Track(A))
+        self.B = nn.Parameter(Track(B))
+
+    def forward(self, obs: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        cat = torch.cat([self.A, self.B], dim=0)
+        return cat[idx] - obs
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_sparse_jacobian_index_after_cat_matches_torch_jacrev(device: str):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    torch.manual_seed(0)
+    dtype = torch.float64
+
+    num_a, num_b = 4, 6
+    dim = 3
+    n = 9
+
+    A0 = torch.randn(num_a, dim, device=device, dtype=dtype)
+    B0 = torch.randn(num_b, dim, device=device, dtype=dtype)
+    obs = torch.randn(n, dim, device=device, dtype=dtype)
+    idx = torch.randint(0, num_a + num_b, (n,), device=device, dtype=torch.int32)
+
+    model = CatIndexResidual(A0, B0)
+    out = model(obs, idx)
+    JA_sparse, JB_sparse = sparse_jacobian(out, [model.A, model.B])
+
+    def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+        cat = torch.cat([A, B], dim=0)
+        return cat[idx] - obs
+
+    JA, JB = jacrev(f, argnums=(0, 1))(A0, B0)
+    torch.testing.assert_close(JA_sparse.to_dense(), _flatten_jac(JA), rtol=1e-10, atol=1e-10)
+    torch.testing.assert_close(JB_sparse.to_dense(), _flatten_jac(JB), rtol=1e-10, atol=1e-10)
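For context on the dense reference used here: `jacrev` keeps the output and input shapes separate, so its result has to be flattened before it can be compared with the 2-D sparse Jacobian (the test's `_flatten_jac` helper presumably does exactly that). A tiny standalone version of the same comparison setup, with made-up sizes:

import torch
from torch.func import jacrev

# Tiny cat-then-index residual, mirroring CatIndexResidual above.
A0 = torch.randn(2, 3, dtype=torch.float64)
B0 = torch.randn(3, 3, dtype=torch.float64)
idx = torch.tensor([4, 0, 2])                   # rows of cat([A, B]) to select
obs = torch.zeros(len(idx), 3, dtype=torch.float64)

def f(A, B):
    return torch.cat([A, B], dim=0)[idx] - obs

JA, JB = jacrev(f, argnums=(0, 1))(A0, B0)
# jacrev gives shape (n, dim, num_a, dim); flatten to (n*dim, num_a*dim)
# to compare against a 2-D sparse Jacobian converted with .to_dense().
JA_flat = JA.reshape(len(idx) * 3, -1)
print(JA_flat.shape)                            # torch.Size([9, 6])
# idx[0] == 4 selects a row of B, so the first output block has zero
# derivative with respect to A.
print(JA_flat[:3].abs().sum().item())           # 0.0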
