Skip to content

Commit 240d4f4

Browse files
authored
improve RMSNorm preformance when torch backend (#21325)
* improve rmsln preformance when torch backend * update
1 parent 9e38e04 commit 240d4f4

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

keras/src/ops/nn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from keras.src.ops import operation_utils
1515
from keras.src.ops.operation import Operation
1616
from keras.src.ops.operation_utils import reduce_shape
17+
from keras.src.utils.python_utils import is_continuous_axis
1718

1819

1920
class Relu(Operation):
@@ -2772,6 +2773,14 @@ def rms_normalization(x, scale=1, axis=-1, epsilon=None):
27722773

27732774

27742775
def _rms_normalization(x, scale=1, axis=-1, epsilon=None):
2776+
if backend.backend() == "torch" and is_continuous_axis(axis):
2777+
import torch.nn.functional as F
2778+
2779+
if isinstance(axis, (tuple, list)):
2780+
normalized_shape = tuple([x.shape[dim] for dim in axis])
2781+
else:
2782+
normalized_shape = x.shape[axis]
2783+
return F.rms_norm(x, normalized_shape, scale, epsilon)
27752784
x = backend.convert_to_tensor(x)
27762785
if len(x.shape) == 0:
27772786
x = backend.numpy.expand_dims(x, axis=0)

keras/src/utils/python_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@
55
import types as python_types
66

77

8+
def is_continuous_axis(axis):
9+
# Used to determine whether the dimensions in an axis are continuous
10+
if isinstance(axis, int) or len(axis) == 1:
11+
return True
12+
positive_order_flag = True
13+
for i in range(len(axis) - 1):
14+
if axis[i + 1] - axis[i] != 1:
15+
positive_order_flag = False
16+
break
17+
18+
negative_order_flag = True
19+
for i in range(len(axis) - 1):
20+
if axis[i + 1] - axis[i] != 1:
21+
negative_order_flag = False
22+
break
23+
return positive_order_flag or negative_order_flag
24+
25+
826
def default(method):
927
"""Decorates a method to detect overrides in subclasses."""
1028
method._is_default = True

0 commit comments

Comments
 (0)