Skip to content

Commit 046fac7

Browse files
mjanuszcopybara-github
authored andcommitted
Move Mogen utils to the connectomics repo.
PiperOrigin-RevId: 876435537
1 parent 60156fd commit 046fac7

File tree

3 files changed

+460
-0
lines changed

3 files changed

+460
-0
lines changed

connectomics/mogen/reorder.py

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
# coding=utf-8
2+
# Copyright 2025 The Google Research Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Reordering methods for point clouds."""
16+
17+
import math
18+
19+
import jax
20+
import jax.numpy as jnp
21+
import numpy as np
22+
import optax
23+
import ott_jax as ott
24+
25+
from google3.research.neuromancer.jax import spatial
26+
27+
28+
@jax.jit
29+
def reorder_axis(pc: jax.Array, axis_index: int) -> jax.Array:
30+
"""Reorders points in a batched point cloud along a given axis index.
31+
32+
Args:
33+
pc: Batched point cloud of shape (..., point_cloud_size, feature_dimension).
34+
axis_index: Index of the feature dimension (last dim) to sort by.
35+
36+
Returns:
37+
Reordered batched point cloud of the same shape.
38+
"""
39+
keys_to_sort = pc[..., axis_index]
40+
argsort_indices = jnp.argsort(keys_to_sort, axis=-1)
41+
# jnp.argpartition into halves is ~2x slower
42+
43+
expanded_indices = jnp.expand_dims(argsort_indices, axis=-1)
44+
return jnp.take_along_axis(pc, expanded_indices, axis=-2)
45+
46+
47+
@jax.jit
48+
def reorder_z_sfc(
49+
pc: jax.Array,
50+
) -> jax.Array:
51+
"""Reorders a point cloud using Z-order-like Space-Filling Curve sorting.
52+
53+
Args:
54+
pc: Batched point cloud of shape (batch_size, point_cloud_size,
55+
feature_dimension). Assumes point_cloud_size is a power of 2.
56+
57+
Returns:
58+
Reordered batched point cloud of the same shape.
59+
"""
60+
depth = int(np.log2(pc.shape[1]))
61+
batch_size, point_cloud_size, feature_dim = pc.shape
62+
63+
current_pc = pc
64+
for d in range(depth):
65+
num_sections = 2**d
66+
section_size = point_cloud_size // num_sections
67+
axis_to_sort = d % feature_dim
68+
reshaped_pc = current_pc.reshape(
69+
batch_size, num_sections, section_size, feature_dim
70+
)
71+
processed_sections = reorder_axis(reshaped_pc, axis_to_sort)
72+
73+
current_pc = processed_sections.reshape(
74+
batch_size, point_cloud_size, feature_dim
75+
)
76+
assert pc.shape == current_pc.shape
77+
return current_pc
78+
79+
80+
@jax.jit
81+
def reorder_z_sfc_with_reverse(
82+
pc: jax.Array,
83+
) -> tuple[jax.Array, jax.Array]:
84+
"""Reorders a point cloud using Z-order-like Space-Filling Curve sorting.
85+
86+
Args:
87+
pc: Batched point cloud of shape (batch_size, point_cloud_size,
88+
feature_dimension). Assumes point_cloud_size is a power of 2.
89+
90+
Returns:
91+
A tuple containing:
92+
- Reordered batched point cloud of the same shape.
93+
- Reverse indices: An array of shape (batch_size, point_cloud_size) such
94+
that `pc == reordered_pc[jnp.arange(batch_size)[:, None],
95+
reverse_indices]`.
96+
"""
97+
batch_size, point_cloud_size, feature_dimension = pc.shape
98+
depth = int(np.log2(point_cloud_size))
99+
assert (
100+
2**depth == point_cloud_size
101+
), f'point_cloud_size must be a power of 2, but is {point_cloud_size}'
102+
103+
current_pc = pc
104+
current_indices = jnp.arange(point_cloud_size)
105+
current_indices = jnp.broadcast_to(
106+
current_indices, (batch_size, point_cloud_size)
107+
)
108+
109+
for d in range(depth):
110+
num_sections = 2**d
111+
section_size = current_pc.shape[1] // num_sections
112+
axis_to_sort = d % feature_dimension
113+
reshaped_pc = current_pc.reshape(
114+
batch_size, num_sections, section_size, feature_dimension
115+
)
116+
reshaped_indices = current_indices.reshape(
117+
batch_size, num_sections, section_size
118+
)
119+
keys_to_sort = reshaped_pc[..., :, axis_to_sort]
120+
argsort_indices_sections = jnp.argsort(keys_to_sort, axis=-1)
121+
expanded_argsort_for_pc = jnp.expand_dims(argsort_indices_sections, axis=-1)
122+
reordered_pc_sections = jnp.take_along_axis(
123+
reshaped_pc,
124+
expanded_argsort_for_pc,
125+
axis=-2,
126+
)
127+
reordered_indices_sections = jnp.take_along_axis(
128+
reshaped_indices,
129+
argsort_indices_sections,
130+
axis=-1,
131+
)
132+
current_pc = reordered_pc_sections.reshape(
133+
batch_size, point_cloud_size, feature_dimension
134+
)
135+
current_indices = reordered_indices_sections.reshape(
136+
batch_size, point_cloud_size
137+
)
138+
assert pc.shape == current_pc.shape
139+
140+
reverse_indices = jnp.argsort(current_indices, axis=1)
141+
142+
return current_pc, reverse_indices
143+
144+
145+
def reorder_distance_from_centroid(
146+
pc: jax.Array, reverse: bool = False
147+
) -> jax.Array:
148+
"""Reorders points based on distance from the point cloud centroid.
149+
150+
Args:
151+
pc: Batched point cloud (batch_size, point_cloud_size, feature_dimension).
152+
reverse: If True, sort from farthest to closest.
153+
154+
Returns:
155+
Reordered batched point cloud.
156+
"""
157+
centroids = jnp.mean(pc, axis=1, keepdims=True)
158+
distances_sq = jnp.sum(jnp.square(pc - centroids), axis=2)
159+
argsort_indices = jnp.argsort(distances_sq, axis=1)
160+
return jnp.take_along_axis(
161+
pc, jnp.expand_dims(argsort_indices, axis=2), axis=1
162+
)[:, :: (-1 if reverse else 1)]
163+
164+
165+
def reorder_distance_from_origin(
166+
pc: jax.Array, reverse: bool = False
167+
) -> jax.Array:
168+
"""Reorders points based on distance from the origin (0,0,0).
169+
170+
Args:
171+
pc: Batched point cloud (batch_size, point_cloud_size, feature_dimension).
172+
reverse: If True, sort from farthest to closest.
173+
174+
Returns:
175+
Reordered batched point cloud.
176+
"""
177+
distances_sq = jnp.sum(jnp.square(pc), axis=2)
178+
argsort_indices = jnp.argsort(distances_sq, axis=1)
179+
return jnp.take_along_axis(
180+
pc, jnp.expand_dims(argsort_indices, axis=2), axis=1
181+
)[:, :: (-1 if reverse else 1)]
182+
183+
184+
def reorder_pc_ot(
185+
pc_a: jax.Array, pc_b: jax.Array
186+
) -> tuple[jax.Array, jax.Array]:
187+
"""Reorders point clouds using optimal transport matching via Sinkhorn.
188+
189+
Args:
190+
pc_a: First point cloud (point_cloud_size, feature_dimension).
191+
pc_b: Second point cloud (point_cloud_size, feature_dimension).
192+
193+
Returns:
194+
Tuple of reordered point clouds
195+
"""
196+
assert pc_a.shape == pc_b.shape
197+
geom = ott.geometry.pointcloud.PointCloud(pc_a, pc_b)
198+
ot = ott.solvers.linear.sinkhorn.solve(geom)
199+
ind_a, ind_b = optax.assignment.hungarian_algorithm(-ot.matrix)
200+
# TODO(riegerfr): get indices directly without using optax/hungarian,
201+
# then assert valid permutation
202+
return pc_a[ind_a], pc_b[ind_b]
203+
204+
205+
vmap_ot = jax.vmap(reorder_pc_ot, in_axes=(0, 0), out_axes=(0, 0))
206+
207+
208+
def reorder_origin_fps(
209+
pc: jax.Array, closest_to_origin: bool = False
210+
) -> jax.Array:
211+
"""Reorders points using FPS starting near/far from origin.
212+
213+
Args:
214+
pc: Point cloud data. (shape: (batch_size, n_points, 3))
215+
closest_to_origin: Whether to order by closest to origin.
216+
217+
Returns:
218+
Reordered point cloud.
219+
"""
220+
norms = jnp.linalg.norm(pc, axis=2)
221+
max_norm_indices = (
222+
jnp.argmin(norms, axis=1)
223+
if closest_to_origin
224+
else jnp.argmax(norms, axis=1)
225+
)
226+
batch_indices = jnp.arange(pc.shape[0])
227+
228+
max_norm_point = pc[batch_indices, max_norm_indices]
229+
first_point = pc[:, 0]
230+
231+
pc = pc.at[:, 0].set(max_norm_point)
232+
pc = pc.at[batch_indices, max_norm_indices].set(first_point)
233+
pc = spatial.subsample_points(pc, pc.shape[1])[0]
234+
return pc
235+
236+
237+
def reorder_named(coord: jax.Array, name: str) -> jax.Array:
238+
"""Reorders points based on a named reordering method.
239+
240+
Args:
241+
coord: Point cloud data. (shape: (batch_size, n_points, 3))
242+
name: Name of the reordering method.
243+
244+
Returns:
245+
Reordered point cloud.
246+
247+
Raises:
248+
ValueError: If the reordering method is not recognized.
249+
"""
250+
if name == 'no':
251+
return coord
252+
base_method = name.split('_')[0]
253+
if name.endswith('_min'):
254+
max_dist = False
255+
else:
256+
max_dist = True
257+
if '_recursive' in name:
258+
recursive = True
259+
else:
260+
recursive = False
261+
if '_first_' in name:
262+
first_n = int(name.split('_first_')[-1].split('_')[0])
263+
else:
264+
first_n = 0
265+
return reorder(coord, base_method, max_dist, recursive, first_n)
266+
267+
268+
def reorder(
269+
coord: jax.Array,
270+
base_method: str | None = None,
271+
max_dist: bool = True,
272+
recursive: bool = False,
273+
first_n: int = 0,
274+
) -> jax.Array:
275+
"""Reorders points based on a named reordering method.
276+
277+
Args:
278+
coord: Point cloud data. (shape: (batch_size, n_points, 3))
279+
base_method: Name of the reordering method.
280+
max_dist: Whether to order by max distance.
281+
recursive: Whether to reorder recursively.
282+
first_n: Number of points to reorder first.
283+
284+
Returns:
285+
Reordered point cloud.
286+
287+
Raises:
288+
ValueError: If the reordering method is not recognized.
289+
"""
290+
291+
if base_method is None:
292+
return coord
293+
294+
reorder_fn = {
295+
'z': reorder_z_sfc,
296+
# TODO(riegerfr): add hilbert
297+
'origin': reorder_distance_from_origin,
298+
'revorigin': lambda x: reorder_distance_from_origin(x, reverse=True),
299+
'centroid': reorder_distance_from_centroid,
300+
'revcentroid': lambda x: reorder_distance_from_centroid(x, reverse=True),
301+
'fps': lambda x: reorder_origin_fps(x, closest_to_origin=max_dist),
302+
}[base_method]
303+
304+
assert not (first_n > 0 and recursive)
305+
306+
if first_n > 0:
307+
assert base_method != 'fps', 'fps cannot be used with first_n'
308+
coord = reorder_origin_fps(coord, closest_to_origin=not max_dist)
309+
coord = jnp.concatenate(
310+
(
311+
reorder_fn(coord[:, :first_n]),
312+
reorder_fn(coord[:, first_n:]),
313+
),
314+
axis=1,
315+
)
316+
elif recursive:
317+
assert base_method != 'fps', 'fps cannot be used with recursive'
318+
coord = reorder_origin_fps(coord, closest_to_origin=not max_dist)
319+
newcoord = jnp.concatenate(
320+
[
321+
reorder_fn(chunk)
322+
for chunk in jnp.split(
323+
coord,
324+
[
325+
2 ** (i + 1)
326+
for i in range(int(math.log2(coord.shape[1])) - 1)
327+
],
328+
axis=1,
329+
)
330+
],
331+
axis=1,
332+
)
333+
assert newcoord.shape == coord.shape
334+
coord = newcoord
335+
else:
336+
coord = reorder_fn(coord)
337+
338+
return coord

0 commit comments

Comments
 (0)