Skip to content

Commit 5594850

Browse files
Update examples/keras_rs/stu.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent a62675b commit 5594850

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

examples/keras_rs/stu.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,7 @@ def keras_concat_2D_jagged_resolver(
317317

318318
def keras_layer_norm(x, weight, bias, eps):
319319
# Functional Layer Norm steps
320-
mean = ops.mean(x, axis=-1, keepdims=True)
321-
variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True)
322-
std = ops.sqrt(variance + eps)
323-
normalized_x = (x - mean) / std
320+
normalized_x = ops.layer_norm(x, axis=-1, epsilon=eps)
324321
return normalized_x * weight + bias
325322

326323

0 commit comments

Comments
 (0)