diff --git a/.github/workflows/test-graph-jacobian.yml b/.github/workflows/test-graph-jacobian.yml index 78e4a73..2287437 100644 --- a/.github/workflows/test-graph-jacobian.yml +++ b/.github/workflows/test-graph-jacobian.yml @@ -36,4 +36,4 @@ jobs: PYTHONPATH: /tmp/pydeps:${{ github.workspace }} run: | python -c "import torch; print('torch', torch.__version__)" - python -m pytest -q tests/autograd/test_graph_jacobian.py + python -m pytest -q tests/autograd diff --git a/datapipes/bal_io.py b/datapipes/bal_io.py new file mode 100644 index 0000000..3a931cf --- /dev/null +++ b/datapipes/bal_io.py @@ -0,0 +1,73 @@ +import os + +import torch + +DTYPE = torch.float64 + + +def _rotvec_to_quat_xyzw(rotvec: torch.Tensor) -> torch.Tensor: + theta = torch.linalg.norm(rotvec, dim=-1, keepdim=True) + half_theta = 0.5 * theta + cos_half = torch.cos(half_theta) + sin_half = torch.sin(half_theta) + + eps = 1e-12 + scale = torch.where(theta > eps, sin_half / theta, 0.5 - (theta * theta) / 48.0) + xyz = rotvec * scale + return torch.cat([xyz, cos_half], dim=-1) + + +def read_bal_data(file_name: str, use_quat: bool = False) -> dict: + """ + Read a Bundle Adjustment in the Large dataset problem text file. + + Format: + + (repeated n_observations) + (n_cameras * 9 lines) + (n_points * 3 lines) + + Each camera has 9 parameters: Rodrigues rotvec (3), translation (3), f, k1, k2. + This loader outputs either: + - use_quat=False: [tx, ty, tz, rx, ry, rz, f, k1, k2] (9) + - use_quat=True: [tx, ty, tz, qx, qy, qz, qw, f, k1, k2] (10) + """ + with open(file_name, "r") as file: + n_cameras, n_points, n_observations = map(int, file.readline().split()) + + camera_indices = torch.empty(n_observations, dtype=torch.int64) + point_indices = torch.empty(n_observations, dtype=torch.int64) + points_2d = torch.empty((n_observations, 2), dtype=DTYPE) + + for i in range(n_observations): + camera_index, point_index, x, y = file.readline().split() + camera_indices[i] = int(camera_index) + point_indices[i] = int(point_index) + points_2d[i, 0] = float(x) + points_2d[i, 1] = float(y) + + camera_params = torch.empty(n_cameras * 9, dtype=DTYPE) + for i in range(n_cameras * 9): + camera_params[i] = float(file.readline()) + camera_params = camera_params.reshape((n_cameras, 9)) + + points_3d = torch.empty(n_points * 3, dtype=DTYPE) + for i in range(n_points * 3): + points_3d[i] = float(file.readline()) + points_3d = points_3d.reshape((n_points, 3)) + + if use_quat: + q = _rotvec_to_quat_xyzw(camera_params[:, :3]) + camera_params = torch.cat([camera_params[:, 3:6], q, camera_params[:, 6:]], dim=1) + else: + camera_params = torch.cat([camera_params[:, 3:6], camera_params[:, :3], camera_params[:, 6:]], dim=1) + + return { + "problem_name": os.path.splitext(os.path.basename(file_name))[0], + "camera_params": camera_params.to(DTYPE), + "points_3d": points_3d, + "points_2d": points_2d, + "camera_index_of_observations": camera_indices, + "point_index_of_observations": point_indices, + } + diff --git a/datapipes/bal_loader.py b/datapipes/bal_loader.py index 0d9d332..7a706b3 100644 --- a/datapipes/bal_loader.py +++ b/datapipes/bal_loader.py @@ -9,20 +9,24 @@ Link to the dataset: https://grail.cs.washington.edu/projects/bal/ """ -import torch, os, warnings -import numpy as np +import os +import warnings + +import torch from functools import partial -from operator import itemgetter, methodcaller -from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning -from torchvision.transforms import Compose -from scipy.spatial.transform import Rotation -from torchdata.datapipes.iter import HttpReader, IterableWrapper, FileOpener -import pypose as pp +from operator import methodcaller -DTYPE = torch.float64 +from .bal_io import DTYPE, read_bal_data -# ignore bs4 warning -warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning) +def _torchdata(): + try: + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper + except ImportError as e: + raise ImportError( + "torchdata is required for datapipes.bal_loader streaming utilities. " + "If you only need parsing, import read_bal_data from datapipes.bal_io." + ) from e + return HttpReader, IterableWrapper, FileOpener # only export __all__ __ALL__ = ['build_pipeline', 'read_bal_data', 'DATA_URL', 'ALL_DATASETS'] @@ -46,8 +50,22 @@ def _not_none(s): # extract problem file urls from the problem url def _problem_lister(*problem_url, cache_dir): + HttpReader, IterableWrapper, FileOpener = _torchdata() + try: + from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning + except ImportError as e: + raise ImportError( + "bs4 is required for datapipes.bal_loader streaming utilities. " + "If you only need parsing, import read_bal_data from datapipes.bal_io." + ) from e + + warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning) + + def _cache_path(url: str) -> str: + return os.path.join(cache_dir, os.path.basename(url)) + problem_list_dp = IterableWrapper(problem_url).on_disk_cache( - filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]), + filepath_fn=_cache_path, ) problem_list_dp = HttpReader(problem_list_dp).end_caching(same_filepath_fn=True) @@ -69,113 +87,28 @@ def _problem_lister(*problem_url, cache_dir): # download and decompress the problem files def _download_pipe(cache_dir, url_dp, suffix: str): + HttpReader, _, _ = _torchdata() + + def _cache_path(url: str) -> str: + return os.path.join(cache_dir, os.path.basename(url)) + + def _strip_suffix(path: str) -> str: + return path.split(suffix)[0] + # cache compressed files cache_compressed = url_dp.on_disk_cache( - filepath_fn=Compose([os.path.basename, partial(os.path.join, cache_dir)]) , + filepath_fn=_cache_path, ) cache_compressed = HttpReader(cache_compressed).end_caching(same_filepath_fn=True) # cache decompressed files cache_decompressed = cache_compressed.on_disk_cache( - filepath_fn=Compose([partial(str.split, sep=suffix), itemgetter(0)]), + filepath_fn=_strip_suffix, ) cache_decompressed = cache_decompressed.open_files(mode="b").load_from_bz2().end_caching( same_filepath_fn=True ) return cache_decompressed -def read_bal_data(file_name: str, use_quat=False) -> dict: - """ - Read a Bundle Adjustment in the Large dataset. - - Referenced Scipy's BAL loader: https://scipy-cookbook.readthedocs.io/items/bundle_adjustment.html - - According to BAL official documentation, each problem is provided as a text file in the following format: - - - - ... - - - ... - - - ... - - - Where, there camera and point indices start from 0. Each camera is a set of 9 parameters - R,t,f,k1 and k2. The rotation R is specified as a Rodrigues' vector. - - Parameters - ---------- - file_name : str - The decompressed file of the dataset. - - Returns - ------- - dict - A dictionary containing the following fields: - - problem_name: str - The name of the problem. - - camera_params: torch.Tensor (n_cameras, 9 or 10) - contains camera parameters for each camera. If use_quat is True, the shape is (n_cameras, 10). - - points_3d: torch.Tensor (n_points, 3) - contains initial estimates of point coordinates in the world frame. - - points_2d: torch.Tensor (n_observations, 2) - contains measured 2-D coordinates of points projected on images in each observations. - - camera_index_of_observations: torch.Tensor (n_observations,) - contains indices of cameras (from 0 to n_cameras - 1) involved in each observation. - - point_index_of_observations: torch.Tensor (n_observations,) - contains indices of points (from 0 to n_points - 1) involved in each observation. - """ - with open(file_name, "r") as file: - n_cameras, n_points, n_observations = map( - int, file.readline().split()) - - camera_indices = torch.empty(n_observations, dtype=torch.int64) - point_indices = torch.empty(n_observations, dtype=torch.int64) - points_2d = torch.empty((n_observations, 2), dtype=DTYPE) - - for i in range(n_observations): - tmp_line = file.readline() - camera_index, point_index, x, y = tmp_line.split() - camera_indices[i] = int(camera_index) - point_indices[i] = int(point_index) - points_2d[i, 0] = float(x) - points_2d[i, 1] = float(y) - - camera_params = torch.empty(n_cameras * 9, dtype=DTYPE) - for i in range(n_cameras * 9): - camera_params[i] = float(file.readline()) - camera_params = camera_params.reshape((n_cameras, -1)) - - points_3d = torch.empty(n_points * 3, dtype=DTYPE) - for i in range(n_points * 3): - points_3d[i] = float(file.readline()) - points_3d = points_3d.reshape((n_points, -1)) - - if use_quat: - # convert Rodrigues vector to unit quaternion for camera rotation - # camera_params[0:3] is the Rodrigues vector - # after conversion, camera_params[0:4] is the unit quaternion - # r = Rotation.from_rotvec(camera_params[:, :3]) - # q = r.as_quat() - r = pp.so3(camera_params[:, :3]) - q = r.Exp() - # [tx, ty, tz, q0, q1, q2, q3, f, k1, k2] - camera_params = torch.cat([camera_params[:, 3:6], q, camera_params[:, 6:]], axis=1) - else: - camera_params = torch.cat([camera_params[:, 3:6], camera_params[:, :3], camera_params[:, 6:]], axis=1) - - # convert camera_params to torch.Tensor - camera_params = torch.tensor(camera_params).to(DTYPE) - - return {'problem_name': os.path.splitext(os.path.basename(file_name))[0], # str - 'camera_params': camera_params, # torch.Tensor (n_cameras, 9 or 10) - 'points_3d': points_3d, # torch.Tensor (n_points, 3) - 'points_2d': points_2d, # torch.Tensor (n_observations, 2) - 'camera_index_of_observations': camera_indices, # torch.Tensor (n_observations,) - 'point_index_of_observations': point_indices, # torch.Tensor (n_observations,) - } - def build_pipeline(dataset='ladybug', cache_dir='bal_data', use_quat=False): """ Build a pipeline for the Bundle Adjustment in the Large dataset. diff --git a/tests/autograd/test_bal_jacobian.py b/tests/autograd/test_bal_jacobian.py new file mode 100644 index 0000000..d03a34d --- /dev/null +++ b/tests/autograd/test_bal_jacobian.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +from pathlib import Path +import bz2 +import os +import shutil +import sys +import urllib.error +import urllib.request +from urllib.parse import urljoin + +import pytest +import torch + +_REPO_ROOT = Path(__file__).resolve().parents[2] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from ba_helpers import Reproj # noqa: E402 +import bae.autograd.graph as autograd_graph # noqa: E402 +from datapipes.bal_io import read_bal_data # noqa: E402 + + +pytestmark = [ + pytest.mark.filterwarnings(r"ignore:CUDA initialization.*:UserWarning"), + pytest.mark.filterwarnings(r"ignore:Sparse BSR tensor support is in beta state.*:UserWarning"), +] + +_BAL_DATA_DIR = _REPO_ROOT / "bal_data" +_BAL_DATA_URL = "https://grail.cs.washington.edu/projects/bal/" +_BAL_SAMPLES: list[tuple[str, str]] = [ + ("trafalgar", "problem-257-65132-pre"), + ("dubrovnik", "problem-356-226730-pre"), + ("ladybug", "problem-1723-156502-pre"), +] + + +def _candidate_bal_urls(dataset: str, bz2_name: str) -> list[str]: + base = _BAL_DATA_URL if _BAL_DATA_URL.endswith("/") else (_BAL_DATA_URL + "/") + prefixes = [ + f"data/{dataset}/", # matches BAL html link format + f"{dataset}/", + f"bal/data/{dataset}/", + "data/", + "bal/data/", + "", + ] + urls: list[str] = [] + seen: set[str] = set() + for prefix in prefixes: + url = urljoin(base, prefix + bz2_name) + if url not in seen: + seen.add(url) + urls.append(url) + return urls + + +def _download_url(url: str, dst_path: Path, *, timeout_s: float = 60.0) -> None: + dst_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = dst_path.with_suffix(dst_path.suffix + ".tmp") + req = urllib.request.Request(url, headers={"User-Agent": "bae-pytest/1.0"}) + try: + with urllib.request.urlopen(req, timeout=timeout_s) as resp, tmp_path.open("wb") as f: + shutil.copyfileobj(resp, f) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + raise + os.replace(tmp_path, dst_path) + + +def _ensure_bal_problem_downloaded(dataset: str, problem_name: str, cache_dir: Path) -> Path: + problem_name = problem_name.removesuffix(".txt").removesuffix(".bz2").removesuffix(".txt") + txt_path = cache_dir / f"{problem_name}.txt" + bz2_path = cache_dir / f"{problem_name}.txt.bz2" + + if txt_path.exists() and txt_path.stat().st_size > 0: + return txt_path + + if not bz2_path.exists() or bz2_path.stat().st_size == 0: + bz2_name = bz2_path.name + last_err: BaseException | None = None + for url in _candidate_bal_urls(dataset, bz2_name): + try: + _download_url(url, bz2_path) + last_err = None + break + except urllib.error.URLError as e: + last_err = e + if last_err is not None: + raise last_err + + tmp_txt = txt_path.with_suffix(".txt.tmp") + try: + with bz2.open(bz2_path, "rb") as src, tmp_txt.open("wb") as dst: + shutil.copyfileobj(src, dst) + except Exception: + try: + tmp_txt.unlink(missing_ok=True) + except Exception: + pass + raise + os.replace(tmp_txt, txt_path) + return txt_path + + +@pytest.fixture(scope="session") +def bal_cache_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: + override = os.environ.get("BAE_BAL_CACHE_DIR") + if override: + return Path(override).expanduser().resolve() + + # Prefer the repository's `bal_data/` if it already contains the samples + # (keeps local development fully offline). + if _BAL_DATA_DIR.exists() and all((_BAL_DATA_DIR / f"{name}.txt").exists() for _, name in _BAL_SAMPLES): + return _BAL_DATA_DIR + + return Path(tmp_path_factory.mktemp("bal_data")) + + +def _load_bal_problem(dataset: str, problem_name: str, cache_dir: Path) -> dict: + # Allow fully-offline runs when *some* BAL samples are already present in the + # repository's `bal_data/`, even if others are missing. + normalized = problem_name.removesuffix(".txt").removesuffix(".bz2").removesuffix(".txt") + local_txt = _BAL_DATA_DIR / f"{normalized}.txt" + if local_txt.exists() and local_txt.stat().st_size > 0: + return read_bal_data(str(local_txt), use_quat=True) + + try: + path = _ensure_bal_problem_downloaded(dataset, problem_name, cache_dir) + except Exception as e: + pytest.skip(f"Could not download BAL sample {dataset}/{problem_name}: {e!r}") + return read_bal_data(str(path), use_quat=True) + + +def _jtj_diag_from_bsr(J: torch.Tensor) -> torch.Tensor: + values = J.values() # (nnz_blocks, block_rows, block_cols) + contrib = (values * values).sum(dim=-2) # (nnz_blocks, block_cols) + col_blocks = J.col_indices().to(torch.int64) + num_blocks = J.shape[1] // values.shape[-1] + diag_blocks = torch.zeros((num_blocks, values.shape[-1]), dtype=values.dtype, device=values.device) + diag_blocks.index_add_(0, col_blocks, contrib) + return diag_blocks.flatten() + + +def _assert_bal_correctness_criteria( + J_cam: torch.Tensor, + J_pts: torch.Tensor, + *, + camera_idx: torch.Tensor, + point_idx: torch.Tensor, + n_cams: int, + n_pts: int, +) -> None: + # Correctness criterion 1: no empty block-columns in each BSR Jacobian. + assert torch.equal(J_cam.col_indices(), camera_idx) + assert torch.equal(J_pts.col_indices(), point_idx) + assert torch.unique(J_cam.col_indices()).numel() == n_cams + assert torch.unique(J_pts.col_indices()).numel() == n_pts + + # Correctness criterion 2: after concatenation, diag(J^T J) is fully occupied. + diag = torch.cat([_jtj_diag_from_bsr(J_cam), _jtj_diag_from_bsr(J_pts)], dim=0) + assert (diag > 0).all() + + +def _remove_camera_and_or_point_appearance( + camera_idx: torch.Tensor, + point_idx: torch.Tensor, + *, + n_cams: int, + n_pts: int, +) -> tuple[torch.Tensor, torch.Tensor]: + # Mutate observation index arrays so at least one camera and/or point ID + # disappears entirely from observations, creating empty block-columns. + remove_camera = bool(torch.randint(0, 2, (1,)).item()) + remove_point = bool(torch.randint(0, 2, (1,)).item()) + if not remove_camera and not remove_point: + remove_camera = True + + camera_idx2 = camera_idx.clone() + point_idx2 = point_idx.clone() + + if remove_camera: + if n_cams < 2: + pytest.skip("BAL sample has <2 cameras; cannot construct removal case.") + cam_remove = int(torch.randint(0, n_cams, (1,)).item()) + cam_offset = int(torch.randint(1, n_cams, (1,)).item()) + cam_repl = (cam_remove + cam_offset) % n_cams + camera_idx2[camera_idx2 == cam_remove] = cam_repl + assert (camera_idx2 == cam_remove).sum().item() == 0 + + if remove_point: + if n_pts < 2: + pytest.skip("BAL sample has <2 points; cannot construct removal case.") + pt_remove = int(torch.randint(0, n_pts, (1,)).item()) + pt_offset = int(torch.randint(1, n_pts, (1,)).item()) + pt_repl = (pt_remove + pt_offset) % n_pts + point_idx2[point_idx2 == pt_remove] = pt_repl + assert (point_idx2 == pt_remove).sum().item() == 0 + + return camera_idx2, point_idx2 + + +@pytest.mark.parametrize( + ("dataset", "problem_name"), + _BAL_SAMPLES, + ids=[f"{ds}.{name}" for ds, name in _BAL_SAMPLES], +) +def test_bal_jacobian_structure_no_empty_columns( + dataset: str, + problem_name: str, + bal_cache_dir: Path, +): + data = _load_bal_problem(dataset, problem_name, bal_cache_dir) + + # CPU-only: CI doesn't have CUDA. + device = torch.device("cpu") + dtype = torch.float64 + + camera_params = data["camera_params"] + points_3d = data["points_3d"] + points_2d = data["points_2d"] + camera_idx = data["camera_index_of_observations"].to(torch.int32) + point_idx = data["point_index_of_observations"].to(torch.int32) + + camera_params = camera_params.to(device=device, dtype=dtype) + points_3d = points_3d.to(device=device, dtype=dtype) + points_2d = points_2d.to(device=device, dtype=dtype) + camera_idx = camera_idx.to(device=device) + point_idx = point_idx.to(device=device) + + model = Reproj(camera_params.clone(), points_3d.clone()).to(device) + residual = model(points_2d, camera_idx, point_idx) + + J_cam, J_pts = autograd_graph.jacobian(residual, [model.pose, model.points_3d]) + assert J_cam.layout == torch.sparse_bsr + assert J_pts.layout == torch.sparse_bsr + + n_cams = model.pose.shape[0] + n_pts = model.points_3d.shape[0] + + _assert_bal_correctness_criteria( + J_cam, + J_pts, + camera_idx=camera_idx, + point_idx=point_idx, + n_cams=n_cams, + n_pts=n_pts, + ) + + +@pytest.mark.parametrize( + ("dataset", "problem_name"), + _BAL_SAMPLES, + ids=[f"{ds}.{name}" for ds, name in _BAL_SAMPLES], +) +def test_bal_jacobian_structure_assert_failed_when_missing_observation_appearance( + dataset: str, + problem_name: str, + bal_cache_dir: Path, +): + data = _load_bal_problem(dataset, problem_name, bal_cache_dir) + + device = torch.device("cpu") + dtype = torch.float64 + + camera_params = data["camera_params"].to(device=device, dtype=dtype) + points_3d = data["points_3d"].to(device=device, dtype=dtype) + points_2d = data["points_2d"].to(device=device, dtype=dtype) + camera_idx = data["camera_index_of_observations"].to(torch.int32).to(device=device) + point_idx = data["point_index_of_observations"].to(torch.int32).to(device=device) + + n_cams = int(camera_params.shape[0]) + n_pts = int(points_3d.shape[0]) + + camera_idx2, point_idx2 = _remove_camera_and_or_point_appearance( + camera_idx, + point_idx, + n_cams=n_cams, + n_pts=n_pts, + ) + + model = Reproj(camera_params.clone(), points_3d.clone()).to(device) + residual = model(points_2d, camera_idx2, point_idx2) + J_cam, J_pts = autograd_graph.jacobian(residual, [model.pose, model.points_3d]) + + with pytest.raises(AssertionError): + _assert_bal_correctness_criteria( + J_cam, + J_pts, + camera_idx=camera_idx2, + point_idx=point_idx2, + n_cams=n_cams, + n_pts=n_pts, + )