Skip to content

Commit 7c0a392

Browse files
committed
Fix multiplication.
1 parent 398d835 commit 7c0a392

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

e3nn_jax/_src/so3grid.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,14 @@ def from_function(
112112
return SO3Signal(s2_signals)
113113

114114
def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
115-
if isinstance(other, float):
116-
return SO3Signal(self.s2_signals * other)
117-
118-
if self.shape != other.shape:
119-
raise ValueError(
120-
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
121-
)
122-
return SO3Signal(self.s2_signals * other.s2_signals)
115+
if isinstance(other, SO3Signal):
116+
if self.shape != other.shape:
117+
raise ValueError(
118+
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
119+
)
120+
return SO3Signal(self.s2_signals * other.s2_signals)
121+
122+
return SO3Signal(self.s2_signals * other)
123123

124124
def __truediv__(self, other: float) -> "SO3Signal":
125125
return self * (1 / other)

0 commit comments

Comments
 (0)