Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
def __truediv__(self, other: float) -> "SO3Signal":
return self * (1 / other)

def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal":
"""Apply a pointwise function to the signal."""
return SO3Signal(self.s2_signals.apply(func))

def vmap_over_batch_dims(
self, func: Callable[..., jnp.ndarray]
) -> Callable[..., jnp.ndarray]:
Expand Down
19 changes: 19 additions & 0 deletions tests/_src/so3grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,22 @@ def test_argmax(seed: int):
R_argmax, _ = sig.argmax()

assert jnp.allclose(func(R_argmax), func(R_argmax_expected), rtol=1e-2)


def test_apply():
sig = SO3Signal.from_function(
lambda R: jnp.trace(R @ R),
res_beta=40,
res_alpha=39,
res_theta=40,
quadrature="gausslegendre",
)
sig_applied = sig.apply(jnp.exp)
sig_expected = SO3Signal.from_function(
lambda R: jnp.exp(jnp.trace(R @ R)),
res_beta=40,
res_alpha=39,
res_theta=40,
quadrature="gausslegendre",
)
assert jnp.allclose(sig_applied.grid_values, sig_expected.grid_values)
Loading