|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2025 The Google Research Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""Reordering methods for point clouds.""" |
| 16 | + |
| 17 | +import math |
| 18 | + |
| 19 | +import jax |
| 20 | +import jax.numpy as jnp |
| 21 | +import numpy as np |
| 22 | +import optax |
| 23 | +import ott_jax as ott |
| 24 | + |
| 25 | +from google3.research.neuromancer.jax import spatial |
| 26 | + |
| 27 | + |
| 28 | +@jax.jit |
| 29 | +def reorder_axis(pc: jax.Array, axis_index: int) -> jax.Array: |
| 30 | + """Reorders points in a batched point cloud along a given axis index. |
| 31 | +
|
| 32 | + Args: |
| 33 | + pc: Batched point cloud of shape (..., point_cloud_size, feature_dimension). |
| 34 | + axis_index: Index of the feature dimension (last dim) to sort by. |
| 35 | +
|
| 36 | + Returns: |
| 37 | + Reordered batched point cloud of the same shape. |
| 38 | + """ |
| 39 | + keys_to_sort = pc[..., axis_index] |
| 40 | + argsort_indices = jnp.argsort(keys_to_sort, axis=-1) |
| 41 | + # jnp.argpartition into halves is ~2x slower |
| 42 | + |
| 43 | + expanded_indices = jnp.expand_dims(argsort_indices, axis=-1) |
| 44 | + return jnp.take_along_axis(pc, expanded_indices, axis=-2) |
| 45 | + |
| 46 | + |
| 47 | +@jax.jit |
| 48 | +def reorder_z_sfc( |
| 49 | + pc: jax.Array, |
| 50 | +) -> jax.Array: |
| 51 | + """Reorders a point cloud using Z-order-like Space-Filling Curve sorting. |
| 52 | +
|
| 53 | + Args: |
| 54 | + pc: Batched point cloud of shape (batch_size, point_cloud_size, |
| 55 | + feature_dimension). Assumes point_cloud_size is a power of 2. |
| 56 | +
|
| 57 | + Returns: |
| 58 | + Reordered batched point cloud of the same shape. |
| 59 | + """ |
| 60 | + depth = int(np.log2(pc.shape[1])) |
| 61 | + batch_size, point_cloud_size, feature_dim = pc.shape |
| 62 | + |
| 63 | + current_pc = pc |
| 64 | + for d in range(depth): |
| 65 | + num_sections = 2**d |
| 66 | + section_size = point_cloud_size // num_sections |
| 67 | + axis_to_sort = d % feature_dim |
| 68 | + reshaped_pc = current_pc.reshape( |
| 69 | + batch_size, num_sections, section_size, feature_dim |
| 70 | + ) |
| 71 | + processed_sections = reorder_axis(reshaped_pc, axis_to_sort) |
| 72 | + |
| 73 | + current_pc = processed_sections.reshape( |
| 74 | + batch_size, point_cloud_size, feature_dim |
| 75 | + ) |
| 76 | + assert pc.shape == current_pc.shape |
| 77 | + return current_pc |
| 78 | + |
| 79 | + |
| 80 | +@jax.jit |
| 81 | +def reorder_z_sfc_with_reverse( |
| 82 | + pc: jax.Array, |
| 83 | +) -> tuple[jax.Array, jax.Array]: |
| 84 | + """Reorders a point cloud using Z-order-like Space-Filling Curve sorting. |
| 85 | +
|
| 86 | + Args: |
| 87 | + pc: Batched point cloud of shape (batch_size, point_cloud_size, |
| 88 | + feature_dimension). Assumes point_cloud_size is a power of 2. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + A tuple containing: |
| 92 | + - Reordered batched point cloud of the same shape. |
| 93 | + - Reverse indices: An array of shape (batch_size, point_cloud_size) such |
| 94 | + that `pc == reordered_pc[jnp.arange(batch_size)[:, None], |
| 95 | + reverse_indices]`. |
| 96 | + """ |
| 97 | + batch_size, point_cloud_size, feature_dimension = pc.shape |
| 98 | + depth = int(np.log2(point_cloud_size)) |
| 99 | + assert ( |
| 100 | + 2**depth == point_cloud_size |
| 101 | + ), f'point_cloud_size must be a power of 2, but is {point_cloud_size}' |
| 102 | + |
| 103 | + current_pc = pc |
| 104 | + current_indices = jnp.arange(point_cloud_size) |
| 105 | + current_indices = jnp.broadcast_to( |
| 106 | + current_indices, (batch_size, point_cloud_size) |
| 107 | + ) |
| 108 | + |
| 109 | + for d in range(depth): |
| 110 | + num_sections = 2**d |
| 111 | + section_size = current_pc.shape[1] // num_sections |
| 112 | + axis_to_sort = d % feature_dimension |
| 113 | + reshaped_pc = current_pc.reshape( |
| 114 | + batch_size, num_sections, section_size, feature_dimension |
| 115 | + ) |
| 116 | + reshaped_indices = current_indices.reshape( |
| 117 | + batch_size, num_sections, section_size |
| 118 | + ) |
| 119 | + keys_to_sort = reshaped_pc[..., :, axis_to_sort] |
| 120 | + argsort_indices_sections = jnp.argsort(keys_to_sort, axis=-1) |
| 121 | + expanded_argsort_for_pc = jnp.expand_dims(argsort_indices_sections, axis=-1) |
| 122 | + reordered_pc_sections = jnp.take_along_axis( |
| 123 | + reshaped_pc, |
| 124 | + expanded_argsort_for_pc, |
| 125 | + axis=-2, |
| 126 | + ) |
| 127 | + reordered_indices_sections = jnp.take_along_axis( |
| 128 | + reshaped_indices, |
| 129 | + argsort_indices_sections, |
| 130 | + axis=-1, |
| 131 | + ) |
| 132 | + current_pc = reordered_pc_sections.reshape( |
| 133 | + batch_size, point_cloud_size, feature_dimension |
| 134 | + ) |
| 135 | + current_indices = reordered_indices_sections.reshape( |
| 136 | + batch_size, point_cloud_size |
| 137 | + ) |
| 138 | + assert pc.shape == current_pc.shape |
| 139 | + |
| 140 | + reverse_indices = jnp.argsort(current_indices, axis=1) |
| 141 | + |
| 142 | + return current_pc, reverse_indices |
| 143 | + |
| 144 | + |
| 145 | +def reorder_distance_from_centroid( |
| 146 | + pc: jax.Array, reverse: bool = False |
| 147 | +) -> jax.Array: |
| 148 | + """Reorders points based on distance from the point cloud centroid. |
| 149 | +
|
| 150 | + Args: |
| 151 | + pc: Batched point cloud (batch_size, point_cloud_size, feature_dimension). |
| 152 | + reverse: If True, sort from farthest to closest. |
| 153 | +
|
| 154 | + Returns: |
| 155 | + Reordered batched point cloud. |
| 156 | + """ |
| 157 | + centroids = jnp.mean(pc, axis=1, keepdims=True) |
| 158 | + distances_sq = jnp.sum(jnp.square(pc - centroids), axis=2) |
| 159 | + argsort_indices = jnp.argsort(distances_sq, axis=1) |
| 160 | + return jnp.take_along_axis( |
| 161 | + pc, jnp.expand_dims(argsort_indices, axis=2), axis=1 |
| 162 | + )[:, :: (-1 if reverse else 1)] |
| 163 | + |
| 164 | + |
| 165 | +def reorder_distance_from_origin( |
| 166 | + pc: jax.Array, reverse: bool = False |
| 167 | +) -> jax.Array: |
| 168 | + """Reorders points based on distance from the origin (0,0,0). |
| 169 | +
|
| 170 | + Args: |
| 171 | + pc: Batched point cloud (batch_size, point_cloud_size, feature_dimension). |
| 172 | + reverse: If True, sort from farthest to closest. |
| 173 | +
|
| 174 | + Returns: |
| 175 | + Reordered batched point cloud. |
| 176 | + """ |
| 177 | + distances_sq = jnp.sum(jnp.square(pc), axis=2) |
| 178 | + argsort_indices = jnp.argsort(distances_sq, axis=1) |
| 179 | + return jnp.take_along_axis( |
| 180 | + pc, jnp.expand_dims(argsort_indices, axis=2), axis=1 |
| 181 | + )[:, :: (-1 if reverse else 1)] |
| 182 | + |
| 183 | + |
| 184 | +def reorder_pc_ot( |
| 185 | + pc_a: jax.Array, pc_b: jax.Array |
| 186 | +) -> tuple[jax.Array, jax.Array]: |
| 187 | + """Reorders point clouds using optimal transport matching via Sinkhorn. |
| 188 | +
|
| 189 | + Args: |
| 190 | + pc_a: First point cloud (point_cloud_size, feature_dimension). |
| 191 | + pc_b: Second point cloud (point_cloud_size, feature_dimension). |
| 192 | +
|
| 193 | + Returns: |
| 194 | + Tuple of reordered point clouds |
| 195 | + """ |
| 196 | + assert pc_a.shape == pc_b.shape |
| 197 | + geom = ott.geometry.pointcloud.PointCloud(pc_a, pc_b) |
| 198 | + ot = ott.solvers.linear.sinkhorn.solve(geom) |
| 199 | + ind_a, ind_b = optax.assignment.hungarian_algorithm(-ot.matrix) |
| 200 | + # TODO(riegerfr): get indices directly without using optax/hungarian, |
| 201 | + # then assert valid permutation |
| 202 | + return pc_a[ind_a], pc_b[ind_b] |
| 203 | + |
| 204 | + |
| 205 | +vmap_ot = jax.vmap(reorder_pc_ot, in_axes=(0, 0), out_axes=(0, 0)) |
| 206 | + |
| 207 | + |
| 208 | +def reorder_origin_fps( |
| 209 | + pc: jax.Array, closest_to_origin: bool = False |
| 210 | +) -> jax.Array: |
| 211 | + """Reorders points using FPS starting near/far from origin. |
| 212 | +
|
| 213 | + Args: |
| 214 | + pc: Point cloud data. (shape: (batch_size, n_points, 3)) |
| 215 | + closest_to_origin: Whether to order by closest to origin. |
| 216 | +
|
| 217 | + Returns: |
| 218 | + Reordered point cloud. |
| 219 | + """ |
| 220 | + norms = jnp.linalg.norm(pc, axis=2) |
| 221 | + max_norm_indices = ( |
| 222 | + jnp.argmin(norms, axis=1) |
| 223 | + if closest_to_origin |
| 224 | + else jnp.argmax(norms, axis=1) |
| 225 | + ) |
| 226 | + batch_indices = jnp.arange(pc.shape[0]) |
| 227 | + |
| 228 | + max_norm_point = pc[batch_indices, max_norm_indices] |
| 229 | + first_point = pc[:, 0] |
| 230 | + |
| 231 | + pc = pc.at[:, 0].set(max_norm_point) |
| 232 | + pc = pc.at[batch_indices, max_norm_indices].set(first_point) |
| 233 | + pc = spatial.subsample_points(pc, pc.shape[1])[0] |
| 234 | + return pc |
| 235 | + |
| 236 | + |
| 237 | +def reorder_named(coord: jax.Array, name: str) -> jax.Array: |
| 238 | + """Reorders points based on a named reordering method. |
| 239 | +
|
| 240 | + Args: |
| 241 | + coord: Point cloud data. (shape: (batch_size, n_points, 3)) |
| 242 | + name: Name of the reordering method. |
| 243 | +
|
| 244 | + Returns: |
| 245 | + Reordered point cloud. |
| 246 | +
|
| 247 | + Raises: |
| 248 | + ValueError: If the reordering method is not recognized. |
| 249 | + """ |
| 250 | + if name == 'no': |
| 251 | + return coord |
| 252 | + base_method = name.split('_')[0] |
| 253 | + if name.endswith('_min'): |
| 254 | + max_dist = False |
| 255 | + else: |
| 256 | + max_dist = True |
| 257 | + if '_recursive' in name: |
| 258 | + recursive = True |
| 259 | + else: |
| 260 | + recursive = False |
| 261 | + if '_first_' in name: |
| 262 | + first_n = int(name.split('_first_')[-1].split('_')[0]) |
| 263 | + else: |
| 264 | + first_n = 0 |
| 265 | + return reorder(coord, base_method, max_dist, recursive, first_n) |
| 266 | + |
| 267 | + |
| 268 | +def reorder( |
| 269 | + coord: jax.Array, |
| 270 | + base_method: str | None = None, |
| 271 | + max_dist: bool = True, |
| 272 | + recursive: bool = False, |
| 273 | + first_n: int = 0, |
| 274 | +) -> jax.Array: |
| 275 | + """Reorders points based on a named reordering method. |
| 276 | +
|
| 277 | + Args: |
| 278 | + coord: Point cloud data. (shape: (batch_size, n_points, 3)) |
| 279 | + base_method: Name of the reordering method. |
| 280 | + max_dist: Whether to order by max distance. |
| 281 | + recursive: Whether to reorder recursively. |
| 282 | + first_n: Number of points to reorder first. |
| 283 | +
|
| 284 | + Returns: |
| 285 | + Reordered point cloud. |
| 286 | +
|
| 287 | + Raises: |
| 288 | + ValueError: If the reordering method is not recognized. |
| 289 | + """ |
| 290 | + |
| 291 | + if base_method is None: |
| 292 | + return coord |
| 293 | + |
| 294 | + reorder_fn = { |
| 295 | + 'z': reorder_z_sfc, |
| 296 | + # TODO(riegerfr): add hilbert |
| 297 | + 'origin': reorder_distance_from_origin, |
| 298 | + 'revorigin': lambda x: reorder_distance_from_origin(x, reverse=True), |
| 299 | + 'centroid': reorder_distance_from_centroid, |
| 300 | + 'revcentroid': lambda x: reorder_distance_from_centroid(x, reverse=True), |
| 301 | + 'fps': lambda x: reorder_origin_fps(x, closest_to_origin=max_dist), |
| 302 | + }[base_method] |
| 303 | + |
| 304 | + assert not (first_n > 0 and recursive) |
| 305 | + |
| 306 | + if first_n > 0: |
| 307 | + assert base_method != 'fps', 'fps cannot be used with first_n' |
| 308 | + coord = reorder_origin_fps(coord, closest_to_origin=not max_dist) |
| 309 | + coord = jnp.concatenate( |
| 310 | + ( |
| 311 | + reorder_fn(coord[:, :first_n]), |
| 312 | + reorder_fn(coord[:, first_n:]), |
| 313 | + ), |
| 314 | + axis=1, |
| 315 | + ) |
| 316 | + elif recursive: |
| 317 | + assert base_method != 'fps', 'fps cannot be used with recursive' |
| 318 | + coord = reorder_origin_fps(coord, closest_to_origin=not max_dist) |
| 319 | + newcoord = jnp.concatenate( |
| 320 | + [ |
| 321 | + reorder_fn(chunk) |
| 322 | + for chunk in jnp.split( |
| 323 | + coord, |
| 324 | + [ |
| 325 | + 2 ** (i + 1) |
| 326 | + for i in range(int(math.log2(coord.shape[1])) - 1) |
| 327 | + ], |
| 328 | + axis=1, |
| 329 | + ) |
| 330 | + ], |
| 331 | + axis=1, |
| 332 | + ) |
| 333 | + assert newcoord.shape == coord.shape |
| 334 | + coord = newcoord |
| 335 | + else: |
| 336 | + coord = reorder_fn(coord) |
| 337 | + |
| 338 | + return coord |
0 commit comments