@@ -41,6 +41,10 @@ def res_beta(self) -> int:
4141 def res_alpha (self ) -> int :
4242 return self .s2_signals .res_alpha
4343
44+ @property
45+ def grid_values (self ) -> jnp .ndarray :
46+ return self .s2_signals .grid_values
47+
4448 @property
4549 def res_theta (self ) -> int :
4650 return self .s2_signals .shape [- 3 ]
@@ -124,13 +128,43 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
124128 def __truediv__ (self , other : float ) -> "SO3Signal" :
125129 return self * (1 / other )
126130
131+ def vmap_over_batch_dims (
132+ self , func : Callable [..., jnp .ndarray ]
133+ ) -> Callable [..., jnp .ndarray ]:
134+ """Apply a function to the signal while preserving the batch dimensions."""
135+ for _ in range (len (self .batch_dims )):
136+ func = jax .vmap (func )
137+ return func
138+
139+ def argmax (self ) -> Tuple [jnp .ndarray , Tuple [jnp .ndarray , jnp .ndarray , jnp .ndarray ]]:
140+ """Find the rotation (and corresponding grid indices) with the maximum value of the signal."""
141+ # Get flattened argmax
142+ flat_index = jnp .argmax (self .grid_values .reshape (* self .shape [:- 3 ], - 1 ), axis = - 1 )
143+
144+ # Convert flat index back to indices for theta, beta, alpha
145+ theta_idx , beta_idx , alpha_idx = jnp .unravel_index (flat_index , self .shape [- 3 :])
146+
147+ # Compute axis.
148+ axis = self .s2_signals .grid_vectors [..., beta_idx , alpha_idx , :]
149+ assert axis .shape == (* self .batch_dims , 3 )
150+
151+ # Compute angle.
152+ angle = self .grid_theta [theta_idx ]
153+ assert angle .shape == (* self .batch_dims ,)
154+
155+ Rs = self .vmap_over_batch_dims (e3nn .axis_angle_to_matrix )(axis , angle )
156+ assert Rs .shape == (* self .batch_dims , 3 , 3 )
157+
158+ return Rs , (theta_idx , beta_idx , alpha_idx )
159+
160+ def replace_values (self , grid_values : jnp .ndarray ) -> "SO3Signal" :
161+ """Replace the values of the signal with the given grid_values."""
162+ return SO3Signal (self .s2_signals .replace_values (grid_values ))
163+
127164 def integrate_over_angles (self ) -> SphericalSignal :
128165 """Integrate the signal over the angles in the axis-angle parametrization."""
129166 # Account for angle-dependency in Haar measure.
130- grid_values = (
131- self .s2_signals .grid_values
132- * (1 - jnp .cos (self .grid_theta ))[..., None , None ]
133- )
167+ grid_values = self .grid_values * (1 - jnp .cos (self .grid_theta ))[..., None , None ]
134168
135169 # Trapezoidal rule for integration.
136170 delta_theta = self .grid_theta [1 ] - self .grid_theta [0 ]
@@ -141,15 +175,15 @@ def integrate_over_angles(self) -> SphericalSignal:
141175 def integrate (self ) -> float :
142176 """Numerically integrate the signal over SO(3)."""
143177 # Integrate over angles.
144- s2_signal_integrated = self .integrate_over_angles ()
145- assert s2_signal_integrated .shape == (
178+ sig_integrated = self .integrate_over_angles ()
179+ assert sig_integrated .shape == (
146180 * self .batch_dims ,
147181 self .res_beta ,
148182 self .res_alpha ,
149183 )
150184
151185 # Integrate over axes using S2 quadrature.
152- integral = s2_signal_integrated .integrate ().array .squeeze (- 1 )
186+ integral = sig_integrated .integrate ().array .squeeze (- 1 )
153187 assert integral .shape == self .batch_dims
154188
155189 # Factor of 8pi^2 from the Haar measure.
@@ -159,22 +193,22 @@ def integrate(self) -> float:
159193 def sample (self , rng : jax .random .PRNGKey ) -> jnp .ndarray :
160194 """Sample a random rotation from SO(3) using the given probability distribution."""
161195 # Integrate over angles.
162- s2_signal_integrated = self .integrate_over_angles ()
163- assert s2_signal_integrated .shape == (
196+ sig_integrated = self .integrate_over_angles ()
197+ assert sig_integrated .shape == (
164198 * self .batch_dims ,
165199 self .res_beta ,
166200 self .res_alpha ,
167201 )
168202
169203 # Sample the axis from the S2 signal (integrated over angles).
170204 axis_rng , rng = jax .random .split (rng )
171- beta_idx , alpha_idx = s2_signal_integrated .sample (axis_rng )
172- axis = s2_signal_integrated .grid_vectors [..., beta_idx , alpha_idx , :]
205+ beta_idx , alpha_idx = sig_integrated .sample (axis_rng )
206+ axis = sig_integrated .grid_vectors [..., beta_idx , alpha_idx , :]
173207 assert axis .shape == (* self .batch_dims , 3 )
174208
175209 # Choose the angle from the distribution conditioned on the axis.
176210 angle_rng , rng = jax .random .split (rng )
177- theta_probs = self .s2_signals . grid_values [..., beta_idx , alpha_idx ]
211+ theta_probs = self .grid_values [..., beta_idx , alpha_idx ]
178212 assert theta_probs .shape == (* self .batch_dims , self .res_theta )
179213
180214 # Avoid log(0) by replacing 0 with a small value.
@@ -185,8 +219,6 @@ def sample(self, rng: jax.random.PRNGKey) -> jnp.ndarray:
185219 angle = jnp .linspace (0 , 2 * jnp .pi , self .res_theta )[theta_idx ]
186220 assert angle .shape == (* self .batch_dims ,)
187221
188- axis_angle_to_matrix = e3nn .axis_angle_to_matrix
189- for _ in range (len (self .batch_dims )):
190- axis_angle_to_matrix = jax .vmap (axis_angle_to_matrix )
191- Rs = axis_angle_to_matrix (axis , angle )
222+ Rs = self .vmap_over_batch_dims (e3nn .axis_angle_to_matrix )(axis , angle )
223+ assert Rs .shape == (* self .batch_dims , 3 , 3 )
192224 return Rs
0 commit comments