Skip to content

Commit 2bd6654

Browse files
committed
Added fuse_rms_norm lowering
1 parent ab59e43 commit 2bd6654

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1919
from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
2020
from .repair_input_as_output import repair_input_as_output
21+
from .replace_fused_rms_norm import replace_fused_rms_norm
2122
from .replace_max_pool_with_indices import replace_max_pool_with_indices
2223
from .rule_based_autocast import rule_based_autocast
2324

2425
pre_lowering_pass_list = [
2526
remove_detach,
2627
remove_assert_nodes,
2728
rule_based_autocast,
29+
replace_fused_rms_norm,
2830
]
2931

3032
post_lowering_pass_list = [
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import logging
2+
3+
import torch
4+
from torch_tensorrt.dynamo._settings import CompilationSettings
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def replace_fused_rms_norm(
10+
gm: torch.fx.GraphModule, settings: CompilationSettings
11+
) -> torch.fx.GraphModule:
12+
"""Replace fused rms norm ops in the graph"""
13+
count = 0
14+
for node in gm.graph.nodes:
15+
if node.target == torch.ops.aten._fused_rms_norm.default:
16+
# Replace fused rms norm with standard rms norm
17+
new_node = gm.graph.call_function(
18+
torch.nn.functional.rms_norm.default,
19+
args=node.args,
20+
)
21+
gm.graph.replace_node_with_new_node(node, new_node)
22+
gm.graph.erase_node(node)
23+
count += 1
24+
25+
logger.debug(f"Replaced {count} fused rms norm nodes:\n{gm.graph}")
26+
27+
return gm

0 commit comments

Comments
 (0)