Skip to content

Commit 87b8a9c

Browse files
author
Yue Pan
committed
[MINOR] more documents
1 parent 77be121 commit 87b8a9c

5 files changed

Lines changed: 111 additions & 25 deletions

File tree

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,8 @@ After building the container, configure the storage path in `start_docker.sh` an
352352
```
353353
sudo chmod +x ./start_docker.sh
354354
./start_docker.sh
355-
356355
```
357356

358-
359357
## Visualizer Instructions
360358

361359
We provide a PIN-SLAM visualizer based on [lidar-visualizer](https://github.com/PRBonn/lidar-visualizer) to monitor the SLAM process. You can use `-v` flag to turn on it.

utils/loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def smooth_sdf_loss(pred, label, delta=20.0, weight=None, weighted=False):
100100
final_loss = ((2.0 / delta) * final_loss * weight).mean()
101101
return final_loss
102102

103-
103+
# deprecated
104104
def ray_estimation_loss(x, y, d_meas): # for each ray
105105
# x as depth
106106
# y as sdf prediction
@@ -120,7 +120,7 @@ def ray_estimation_loss(x, y, d_meas): # for each ray
120120

121121
return d_error
122122

123-
123+
# deprecated
124124
def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch]
125125
# x as depth
126126
# y as occ.prob. prediction
@@ -140,7 +140,7 @@ def ray_rendering_loss(x, y, d_meas): # for each ray [should run in batch]
140140

141141
return d_error
142142

143-
143+
# deprecated
144144
def batch_ray_rendering_loss(x, y, d_meas, neus_on=True): # for all rays in a batch
145145
# x as depth [ray number * sample number]
146146
# y as prediction (the alpha in volume rendering) [ray number * sample number]

utils/mesher.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ def get_query_from_bbx(self, bbx, voxel_size, pad_voxel=0, skip_top_voxel=0):
213213
return coord, voxel_num_xyz, voxel_origin
214214

215215
def get_query_from_hor_slice(self, bbx, slice_z, voxel_size):
216-
"""get grid query points inside a given bounding box (bbx) at slice height (slice_z)"""
216+
"""
217+
get grid query points inside a given bounding box (bbx) at slice height (slice_z)
218+
"""
217219
# bbx and voxel_size are all in the world coordinate system
218220
min_bound = bbx.get_min_bound()
219221
max_bound = bbx.get_max_bound()
@@ -246,7 +248,9 @@ def get_query_from_hor_slice(self, bbx, slice_z, voxel_size):
246248
return coord, voxel_num_xyz, voxel_origin
247249

248250
def get_query_from_ver_slice(self, bbx, slice_x, voxel_size):
249-
"""get grid query points inside a given bounding box (bbx) at slice position (slice_x)"""
251+
"""
252+
get grid query points inside a given bounding box (bbx) at slice position (slice_x)
253+
"""
250254
# bbx and voxel_size are all in the world coordinate system
251255
min_bound = bbx.get_min_bound()
252256
max_bound = bbx.get_max_bound()
@@ -279,6 +283,9 @@ def get_query_from_ver_slice(self, bbx, slice_x, voxel_size):
279283
return coord, voxel_num_xyz, voxel_origin
280284

281285
def generate_sdf_map(self, coord, sdf_pred, mc_mask):
286+
"""
287+
Generate the SDF map for saving
288+
"""
282289
device = o3d.core.Device("CPU:0")
283290
dtype = o3d.core.float32
284291
sdf_map_pc = o3d.t.geometry.PointCloud(device)
@@ -305,7 +312,9 @@ def generate_sdf_map(self, coord, sdf_pred, mc_mask):
305312
def generate_sdf_map_for_vis(
306313
self, coord, sdf_pred, mc_mask, min_sdf=-1.0, max_sdf=1.0, cmap="bwr"
307314
): # 'jet','bwr','viridis'
308-
315+
"""
316+
Generate the SDF map for visualization
317+
"""
309318
# do the masking or not
310319
if mc_mask is not None:
311320
coord = coord[mc_mask > 0]
@@ -392,6 +401,9 @@ def mc_mesh(self, mc_sdf, mc_mask, voxel_size, mc_origin):
392401
return verts, faces
393402

394403
def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True):
404+
"""
405+
Predict the semantic label of the vertices
406+
"""
395407
if len(verts) == 0:
396408
return mesh
397409

@@ -413,6 +425,9 @@ def estimate_vertices_sem(self, mesh, verts, filter_free_space_vertices=True):
413425
return mesh
414426

415427
def estimate_vertices_color(self, mesh, verts):
428+
"""
429+
Predict the color of the vertices
430+
"""
416431
if len(verts) == 0:
417432
return mesh
418433

@@ -430,7 +445,9 @@ def estimate_vertices_color(self, mesh, verts):
430445
return mesh
431446

432447
def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300):
433-
# print("Cluster connected triangles")
448+
"""
449+
Cluster connected triangles and remove the small clusters
450+
"""
434451
triangle_clusters, cluster_n_triangles, _ = mesh.cluster_connected_triangles()
435452
triangle_clusters = np.asarray(triangle_clusters)
436453
cluster_n_triangles = np.asarray(cluster_n_triangles)
@@ -445,6 +462,9 @@ def filter_isolated_vertices(self, mesh, filter_cluster_min_tri=300):
445462
def generate_bbx_sdf_hor_slice(
446463
self, bbx, slice_z, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0
447464
):
465+
"""
466+
Generate the SDF slice at height (slice_z)
467+
"""
448468
# print("Generate the SDF slice at heright %.2f (m)" % (slice_z))
449469
coord, _, _ = self.get_query_from_hor_slice(bbx, slice_z, voxel_size)
450470
sdf_pred, _, _, mc_mask = self.query_points(
@@ -466,6 +486,9 @@ def generate_bbx_sdf_hor_slice(
466486
def generate_bbx_sdf_ver_slice(
467487
self, bbx, slice_x, voxel_size, query_locally=False, min_sdf=-1.0, max_sdf=1.0
468488
):
489+
"""
490+
Generate the SDF slice at x position (slice_x)
491+
"""
469492
# print("Generate the SDF slice at x position %.2f (m)" % (slice_x))
470493
coord, _, _ = self.get_query_from_ver_slice(bbx, slice_x, voxel_size)
471494
sdf_pred, _, _, mc_mask = self.query_points(
@@ -499,6 +522,9 @@ def recon_aabb_collections_mesh(
499522
mesh_min_nn=10,
500523
use_torch_mc=False,
501524
):
525+
"""
526+
Reconstruct the mesh from a collection of bounding boxes
527+
"""
502528
if not self.silence:
503529
print("# Chunk for meshing: ", len(aabbs))
504530

@@ -545,7 +571,9 @@ def recon_aabb_mesh(
545571
mesh_min_nn=10,
546572
use_torch_mc=False,
547573
):
548-
574+
"""
575+
Reconstruct the mesh from a given bounding box
576+
"""
549577
# reconstruct and save the (semantic) mesh from the feature octree the decoders within a
550578
# given bounding box. bbx and voxel_size all with unit m, in world coordinate system
551579
coord, voxel_num_xyz, voxel_origin = self.get_query_from_bbx(

utils/tools.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,10 @@ def step_lr_decay(
227227
return learning_rate
228228

229229

230-
# calculate the analytical gradient by pytorch auto diff
231230
def get_gradient(inputs, outputs):
231+
"""
232+
Calculate the analytical gradient by pytorch auto diff
233+
"""
232234
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
233235
points_grad = grad(
234236
outputs=outputs,
@@ -387,8 +389,10 @@ def create_axis_aligned_bounding_box(center, size):
387389

388390

389391
def apply_quaternion_rotation(quat: torch.tensor, points: torch.tensor) -> torch.tensor:
390-
# apply passive rotation: coordinate system rotation w.r.t. the points
391-
# p' = qpq^-1
392+
"""
393+
Apply passive rotation: coordinate system rotation w.r.t. the points
394+
p' = qpq^-1
395+
"""
392396
quat_w = quat[..., 0].unsqueeze(-1)
393397
quat_xyz = -quat[..., 1:]
394398
t = 2 * torch.linalg.cross(quat_xyz, points)
@@ -416,6 +420,11 @@ def rotmat_to_quat(rot_matrix: torch.tensor):
416420

417421

418422
def quat_to_rotmat(quaternions: torch.tensor):
423+
"""
424+
Convert a batch of quaternions to rotation matrices.
425+
quaternions: N,4
426+
return N,3,3
427+
"""
419428
# Ensure quaternions are normalized
420429
quaternions /= torch.norm(quaternions, dim=1, keepdim=True)
421430

@@ -469,19 +478,32 @@ def quat_multiply(q1: torch.tensor, q2: torch.tensor):
469478

470479

471480
def torch2o3d(points_torch):
481+
"""
482+
Convert a batch of points from torch to o3d
483+
"""
472484
pc_o3d = o3d.geometry.PointCloud()
473485
points_np = points_torch.cpu().detach().numpy().astype(np.float64)
474486
pc_o3d.points = o3d.utility.Vector3dVector(points_np)
475487
return pc_o3d
476488

477489

478490
def o3d2torch(o3d, device="cpu", dtype=torch.float32):
491+
"""
492+
Convert a batch of points from o3d to torch
493+
"""
479494
return torch.tensor(np.asarray(o3d.points), dtype=dtype, device=device)
480495

481496

482497
def transform_torch(points: torch.tensor, transformation: torch.tensor):
483-
# points [N, 3]
484-
# transformation [4, 4]
498+
"""
499+
Transform a batch of points by a transformation matrix
500+
Args:
501+
points: N,3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
502+
kaolin coordinate system [-1,1]
503+
transformation: 4,4 torch tensor, the transformation matrix
504+
Returns:
505+
transformed_points: N,3 torch tensor, the transformed coordinates
506+
"""
485507
# Add a homogeneous coordinate to each point in the point cloud
486508
points_homo = torch.cat([points, torch.ones(points.shape[0], 1).to(points)], dim=1)
487509

@@ -495,9 +517,15 @@ def transform_torch(points: torch.tensor, transformation: torch.tensor):
495517

496518

497519
def transform_batch_torch(points: torch.tensor, transformation: torch.tensor):
498-
# points [N, 3]
499-
# transformation [N, 4, 4]
500-
# N,3,3 @ N,3,1 -> N,3,1 + N,3,1 -> N,3,1 -> N,3
520+
"""
521+
Transform a batch of points by a batch of transformation matrices
522+
Args:
523+
points: N,3 torch tensor, the coordinates of all N (axbxc) query points in the scaled
524+
kaolin coordinate system [-1,1]
525+
transformation: N,4,4 torch tensor, the transformation matrices
526+
Returns:
527+
transformed_points: N,3 torch tensor, the transformed coordinates
528+
"""
501529

502530
# Extract rotation and translation components
503531
rotation = transformation[:, :3, :3].to(points)
@@ -609,7 +637,9 @@ def split_chunks(
609637
aabb: o3d.geometry.AxisAlignedBoundingBox(),
610638
chunk_m: float = 100.0
611639
):
612-
640+
"""
641+
Split a large point cloud into bounding box chunks
642+
"""
613643
if not pc.has_points():
614644
return None
615645

@@ -680,7 +710,9 @@ def split_chunks(
680710
def deskewing(
681711
points: torch.tensor, ts: torch.tensor, pose: torch.tensor, ts_mid_pose=0.5
682712
):
683-
713+
"""
714+
Deskew a batch of points at timestamp ts by a relative transformation matrix
715+
"""
684716
if ts is None:
685717
return points # no deskewing
686718

@@ -711,7 +743,9 @@ def deskewing(
711743

712744

713745
def tranmat_close_to_identity(mats: np.ndarray, rot_thre: float, tran_thre: float):
714-
746+
"""
747+
Check if a batch of transformation matrices is close to identity
748+
"""
715749
rot_diff = np.abs(mats[:3, :3] - np.identity(3))
716750

717751
rot_close_to_identity = np.all(rot_diff < rot_thre)
@@ -781,7 +815,9 @@ def feature_pca_torch(data, principal_components = None,
781815
return data_pca, principal_components
782816

783817
def plot_timing_detail(time_table: np.ndarray, saving_path: str, with_loop=False):
784-
818+
"""
819+
Plot the timing detail for processing per frame
820+
"""
785821
frame_count = time_table.shape[0]
786822
time_table_ms = time_table * 1e3
787823

utils/tracker.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,21 @@ def tracking(
5454
loop_reg: bool = False,
5555
vis_result: bool = False,
5656
):
57-
57+
"""
58+
Perform tracking
59+
Args:
60+
source_points: N,3 torch tensor, the coordinates of all N query points
61+
init_pose: 4,4 torch tensor, the initial pose
62+
source_colors: N,3 torch tensor, the colors of all N query points
63+
source_normals: N,3 torch tensor, the normals of all N query points
64+
source_sdf: N torch tensor, the SDF values of all N query points
65+
cur_ts: float, the timestamp of the current frame
66+
loop_reg: bool, whether this is a registration for loop closure
67+
vis_result: bool, whether to visualize the result
68+
Returns:
69+
T: 4,4 torch tensor, the final pose
70+
cov_mat: 6,6 torch tensor, the covariance matrix
71+
"""
5872
if init_pose is None:
5973
T = torch.eye(4, dtype=torch.float64, device=self.device)
6074
else:
@@ -365,7 +379,9 @@ def registration_step(
365379
lm_lambda=0.0,
366380
vis_weight_pc=False,
367381
): # if lm_lambda = 0, then it's Gaussian Newton Optimization
368-
382+
"""
383+
Perform one step of registration
384+
"""
369385
T0 = get_time()
370386

371387
colors_on = colors is not None and self.config.color_on
@@ -757,6 +773,9 @@ def ct_registration_step(
757773

758774
# math tools
759775
def skew(v):
776+
"""
777+
Compute the skew-symmetric matrix of a 3D vector
778+
"""
760779
S = torch.zeros(3, 3, device=v.device, dtype=v.dtype)
761780
S[0, 1] = -v[2]
762781
S[0, 2] = v[1]
@@ -765,7 +784,9 @@ def skew(v):
765784

766785

767786
def expmap(axis_angle: torch.Tensor):
768-
787+
"""
788+
Convert an axis-angle representation to a rotation matrix
789+
"""
769790
angle = axis_angle.norm()
770791
axis = axis_angle / angle
771792
eye = torch.eye(3, device=axis_angle.device, dtype=axis_angle.dtype)
@@ -777,6 +798,9 @@ def expmap(axis_angle: torch.Tensor):
777798

778799

779800
def rotation_matrix_to_axis_angle(R):
801+
"""
802+
Convert a rotation matrix to an axis-angle representation
803+
"""
780804
# epsilon = 1e-8 # A small value to handle numerical precision issues
781805
# Ensure the input matrix is a valid rotation matrix
782806
assert torch.is_tensor(R) and R.shape == (3, 3), "Invalid rotation matrix"

0 commit comments

Comments
 (0)