Skip to content

Commit 5d72147

Browse files
committed
Add division and negation for SO3Signal.
1 parent 08a7e81 commit 5d72147

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

e3nn_jax/_src/so3grid.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,23 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
125125

126126
return SO3Signal(self.s2_signals * other)
127127

128-
def __truediv__(self, other: float) -> "SO3Signal":
128+
def __rmul__(self, other: float) -> "SO3Signal":
129+
return self * other
130+
131+
def __neg__(self) -> "SO3Signal":
132+
return self * -1
133+
134+
def __truediv__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
135+
if isinstance(other, SO3Signal):
136+
if self.shape != other.shape:
137+
raise ValueError(
138+
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
139+
)
140+
141+
return self.replace_values(
142+
self.grid_values / other.grid_values
143+
)
144+
129145
return self * (1 / other)
130146

131147
def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal":

0 commit comments

Comments
 (0)