@@ -234,11 +234,11 @@ def get_parameter(name: str, path_shape: Tuple[int, ...], weight_std: float):
234234
235235 self .get_parameter = get_parameter
236236
237- def __call__ (self , weights : Optional [jnp .ndarray ], input : IrrepsArray = None ) -> IrrepsArray :
237+ def __call__ (self , weights : Optional [Union [ IrrepsArray , jnp .ndarray ] ], input : IrrepsArray = None ) -> IrrepsArray :
238238 """Apply the linear operator.
239239
240240 Args:
241- weights (optional jnp.ndarray): scalar weights that are contracted with free parameters.
241+ weights (optional IrrepsArray or jnp.ndarray): scalar weights that are contracted with free parameters.
242242 An array of shape ``(..., num_weights)``. Broadcasting with `input` is supported.
243243 input (IrrepsArray): input irreps-array of shape ``(..., [channel_in,] irreps_in.dim)``.
244244 Broadcasting with `weights` is supported.
@@ -285,6 +285,11 @@ def __call__(self, weights: Optional[jnp.ndarray], input: IrrepsArray = None) ->
285285 f = jax .vmap (f )
286286 output = f (input )
287287 else :
288+ if isinstance (weights , IrrepsArray ):
289+ if not weights .irreps .is_scalar ():
290+ raise ValueError ("weights must be scalar" )
291+ weights = weights .array
292+
288293 shape = jnp .broadcast_shapes (input .shape [:- 1 ], weights .shape [:- 1 ])
289294 input = input .broadcast_to (shape + (- 1 ,))
290295 weights = jnp .broadcast_to (weights , shape + weights .shape [- 1 :])
0 commit comments