We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a62675b commit 5594850Copy full SHA for 5594850
examples/keras_rs/stu.py
@@ -317,10 +317,7 @@ def keras_concat_2D_jagged_resolver(
317
318
def keras_layer_norm(x, weight, bias, eps):
319
# Functional Layer Norm steps
320
- mean = ops.mean(x, axis=-1, keepdims=True)
321
- variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True)
322
- std = ops.sqrt(variance + eps)
323
- normalized_x = (x - mean) / std
+ normalized_x = ops.layer_norm(x, axis=-1, epsilon=eps)
324
return normalized_x * weight + bias
325
326
0 commit comments