Skip to content

Commit e74cd40

Browse files
committed
add support for IrrepsArray weights in Linear (only scalars)
1 parent 7e6d039 commit e74cd40

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

e3nn_jax/_src/linear.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,11 @@ def get_parameter(name: str, path_shape: Tuple[int, ...], weight_std: float):
234234

235235
self.get_parameter = get_parameter
236236

237-
def __call__(self, weights: Optional[jnp.ndarray], input: IrrepsArray = None) -> IrrepsArray:
237+
def __call__(self, weights: Optional[Union[IrrepsArray, jnp.ndarray]], input: IrrepsArray = None) -> IrrepsArray:
238238
"""Apply the linear operator.
239239
240240
Args:
241-
weights (optional jnp.ndarray): scalar weights that are contracted with free parameters.
241+
weights (optional IrrepsArray or jnp.ndarray): scalar weights that are contracted with free parameters.
242242
An array of shape ``(..., num_weights)``. Broadcasting with `input` is supported.
243243
input (IrrepsArray): input irreps-array of shape ``(..., [channel_in,] irreps_in.dim)``.
244244
Broadcasting with `weights` is supported.
@@ -285,6 +285,11 @@ def __call__(self, weights: Optional[jnp.ndarray], input: IrrepsArray = None) ->
285285
f = jax.vmap(f)
286286
output = f(input)
287287
else:
288+
if isinstance(weights, IrrepsArray):
289+
if not weights.irreps.is_scalar():
290+
raise ValueError("weights must be scalar")
291+
weights = weights.array
292+
288293
shape = jnp.broadcast_shapes(input.shape[:-1], weights.shape[:-1])
289294
input = input.broadcast_to(shape + (-1,))
290295
weights = jnp.broadcast_to(weights, shape + weights.shape[-1:])

0 commit comments

Comments
 (0)