@@ -127,15 +127,18 @@ def __truediv__(self, other: float) -> "SO3Signal":
127127 def integrate_over_angles (self ) -> SphericalSignal :
128128 """Integrate the signal over the angles in the axis-angle parametrization."""
129129 # Account for angle-dependency in Haar measure.
130- grid_values = self .s2_signals .grid_values * (1 - jnp .cos (self .grid_theta ))[..., None , None ]
130+ grid_values = (
131+ self .s2_signals .grid_values
132+ * (1 - jnp .cos (self .grid_theta ))[..., None , None ]
133+ )
131134
132135 # Trapezoidal rule for integration.
133136 delta_theta = self .grid_theta [1 ] - self .grid_theta [0 ]
134137 return self .s2_signals .replace_values (
135138 grid_values = jnp .sum (grid_values , axis = - 3 ) * delta_theta
136139 )
137140
138- def integrate (self ) -> SphericalSignal :
141+ def integrate (self ) -> float :
139142 """Numerically integrate the signal over SO(3)."""
140143 # Integrate over angles.
141144 s2_signal_integrated = self .integrate_over_angles ()
@@ -153,7 +156,7 @@ def integrate(self) -> SphericalSignal:
153156 integral = integral / (8 * jnp .pi ** 2 )
154157 return integral
155158
156- def sample (self , rng : jax .random .PRNGKey ):
159+ def sample (self , rng : jax .random .PRNGKey ) -> jnp . ndarray :
157160 """Sample a random rotation from SO(3) using the given probability distribution."""
158161 # Integrate over angles.
159162 s2_signal_integrated = self .integrate_over_angles ()
0 commit comments