Skip to content

Commit a595bf1

Browse files
authored
Merge pull request #92 from e3nn/so3
Fix formatting.
2 parents c47ace2 + 0f879f0 commit a595bf1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

e3nn_jax/_src/so3grid.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def vmap_over_batch_dims(
136136
func = jax.vmap(func)
137137
return func
138138

139-
def argmax(self) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
139+
def argmax(
140+
self,
141+
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
140142
"""Find the rotation (and corresponding grid indices) with the maximum value of the signal."""
141143
# Get flattened argmax
142144
flat_index = jnp.argmax(self.grid_values.reshape(*self.shape[:-3], -1), axis=-1)

0 commit comments

Comments
 (0)