Skip to content

Commit c47ace2

Browse files
authored
Merge pull request #91 from e3nn/so3
Add argmax for signals on SO3
2 parents 941afa1 + 8be4994 commit c47ace2

File tree

2 files changed

+71
-16
lines changed

2 files changed

+71
-16
lines changed

e3nn_jax/_src/so3grid.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def res_beta(self) -> int:
4141
def res_alpha(self) -> int:
4242
return self.s2_signals.res_alpha
4343

44+
@property
45+
def grid_values(self) -> jnp.ndarray:
46+
return self.s2_signals.grid_values
47+
4448
@property
4549
def res_theta(self) -> int:
4650
return self.s2_signals.shape[-3]
@@ -124,13 +128,43 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
124128
def __truediv__(self, other: float) -> "SO3Signal":
125129
return self * (1 / other)
126130

131+
def vmap_over_batch_dims(
132+
self, func: Callable[..., jnp.ndarray]
133+
) -> Callable[..., jnp.ndarray]:
134+
"""Apply a function to the signal while preserving the batch dimensions."""
135+
for _ in range(len(self.batch_dims)):
136+
func = jax.vmap(func)
137+
return func
138+
139+
def argmax(self) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
140+
"""Find the rotation (and corresponding grid indices) with the maximum value of the signal."""
141+
# Get flattened argmax
142+
flat_index = jnp.argmax(self.grid_values.reshape(*self.shape[:-3], -1), axis=-1)
143+
144+
# Convert flat index back to indices for theta, beta, alpha
145+
theta_idx, beta_idx, alpha_idx = jnp.unravel_index(flat_index, self.shape[-3:])
146+
147+
# Compute axis.
148+
axis = self.s2_signals.grid_vectors[..., beta_idx, alpha_idx, :]
149+
assert axis.shape == (*self.batch_dims, 3)
150+
151+
# Compute angle.
152+
angle = self.grid_theta[theta_idx]
153+
assert angle.shape == (*self.batch_dims,)
154+
155+
Rs = self.vmap_over_batch_dims(e3nn.axis_angle_to_matrix)(axis, angle)
156+
assert Rs.shape == (*self.batch_dims, 3, 3)
157+
158+
return Rs, (theta_idx, beta_idx, alpha_idx)
159+
160+
def replace_values(self, grid_values: jnp.ndarray) -> "SO3Signal":
161+
"""Replace the values of the signal with the given grid_values."""
162+
return SO3Signal(self.s2_signals.replace_values(grid_values))
163+
127164
def integrate_over_angles(self) -> SphericalSignal:
128165
"""Integrate the signal over the angles in the axis-angle parametrization."""
129166
# Account for angle-dependency in Haar measure.
130-
grid_values = (
131-
self.s2_signals.grid_values
132-
* (1 - jnp.cos(self.grid_theta))[..., None, None]
133-
)
167+
grid_values = self.grid_values * (1 - jnp.cos(self.grid_theta))[..., None, None]
134168

135169
# Trapezoidal rule for integration.
136170
delta_theta = self.grid_theta[1] - self.grid_theta[0]
@@ -141,15 +175,15 @@ def integrate_over_angles(self) -> SphericalSignal:
141175
def integrate(self) -> float:
142176
"""Numerically integrate the signal over SO(3)."""
143177
# Integrate over angles.
144-
s2_signal_integrated = self.integrate_over_angles()
145-
assert s2_signal_integrated.shape == (
178+
sig_integrated = self.integrate_over_angles()
179+
assert sig_integrated.shape == (
146180
*self.batch_dims,
147181
self.res_beta,
148182
self.res_alpha,
149183
)
150184

151185
# Integrate over axes using S2 quadrature.
152-
integral = s2_signal_integrated.integrate().array.squeeze(-1)
186+
integral = sig_integrated.integrate().array.squeeze(-1)
153187
assert integral.shape == self.batch_dims
154188

155189
# Factor of 8pi^2 from the Haar measure.
@@ -159,22 +193,22 @@ def integrate(self) -> float:
159193
def sample(self, rng: jax.random.PRNGKey) -> jnp.ndarray:
160194
"""Sample a random rotation from SO(3) using the given probability distribution."""
161195
# Integrate over angles.
162-
s2_signal_integrated = self.integrate_over_angles()
163-
assert s2_signal_integrated.shape == (
196+
sig_integrated = self.integrate_over_angles()
197+
assert sig_integrated.shape == (
164198
*self.batch_dims,
165199
self.res_beta,
166200
self.res_alpha,
167201
)
168202

169203
# Sample the axis from the S2 signal (integrated over angles).
170204
axis_rng, rng = jax.random.split(rng)
171-
beta_idx, alpha_idx = s2_signal_integrated.sample(axis_rng)
172-
axis = s2_signal_integrated.grid_vectors[..., beta_idx, alpha_idx, :]
205+
beta_idx, alpha_idx = sig_integrated.sample(axis_rng)
206+
axis = sig_integrated.grid_vectors[..., beta_idx, alpha_idx, :]
173207
assert axis.shape == (*self.batch_dims, 3)
174208

175209
# Choose the angle from the distribution conditioned on the axis.
176210
angle_rng, rng = jax.random.split(rng)
177-
theta_probs = self.s2_signals.grid_values[..., beta_idx, alpha_idx]
211+
theta_probs = self.grid_values[..., beta_idx, alpha_idx]
178212
assert theta_probs.shape == (*self.batch_dims, self.res_theta)
179213

180214
# Avoid log(0) by replacing 0 with a small value.
@@ -185,8 +219,6 @@ def sample(self, rng: jax.random.PRNGKey) -> jnp.ndarray:
185219
angle = jnp.linspace(0, 2 * jnp.pi, self.res_theta)[theta_idx]
186220
assert angle.shape == (*self.batch_dims,)
187221

188-
axis_angle_to_matrix = e3nn.axis_angle_to_matrix
189-
for _ in range(len(self.batch_dims)):
190-
axis_angle_to_matrix = jax.vmap(axis_angle_to_matrix)
191-
Rs = axis_angle_to_matrix(axis, angle)
222+
Rs = self.vmap_over_batch_dims(e3nn.axis_angle_to_matrix)(axis, angle)
223+
assert Rs.shape == (*self.batch_dims, 3, 3)
192224
return Rs

tests/_src/so3grid_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,26 @@ def test_division_scalar():
9696
sig2 = sig1 / 2.7
9797
integral2 = sig2.integrate()
9898
assert jnp.isclose(integral2, integral1 / 2.7)
99+
100+
101+
@pytest.mark.parametrize("seed", [0, 1, 2])
102+
def test_argmax(seed: int):
103+
rng = jax.random.PRNGKey(seed)
104+
F = jax.random.normal(rng, (3, 3))
105+
106+
func = lambda R: jnp.exp(jnp.trace(F.T @ R))
107+
sig = SO3Signal.from_function(
108+
func,
109+
res_beta=50,
110+
res_alpha=50,
111+
res_theta=50,
112+
quadrature="gausslegendre",
113+
)
114+
115+
U, S, VT = jnp.linalg.svd(F)
116+
R_argmax_expected = (
117+
U @ jnp.diag(jnp.asarray([1.0, 1.0, jnp.linalg.det(U @ VT)])) @ VT
118+
)
119+
R_argmax, _ = sig.argmax()
120+
121+
assert jnp.allclose(func(R_argmax), func(R_argmax_expected), rtol=1e-2)

0 commit comments

Comments
 (0)