File tree Expand file tree Collapse file tree 2 files changed +29
-0
lines changed
py/torch_tensorrt/dynamo/lowering/passes Expand file tree Collapse file tree 2 files changed +29
-0
lines changed Original file line number Diff line number Diff line change 1818from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1919from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
2020from .repair_input_as_output import repair_input_as_output
21+ from .replace_fused_rms_norm import replace_fused_rms_norm
2122from .replace_max_pool_with_indices import replace_max_pool_with_indices
2223from .rule_based_autocast import rule_based_autocast
2324
# Passes applied to the graph before ATen decomposition/lowering.
# Order matters: autocast runs before the fused-RMS-norm rewrite so the
# replacement node inherits the already-adjusted dtypes.
pre_lowering_pass_list = [
    remove_detach,
    remove_assert_nodes,
    rule_based_autocast,
    replace_fused_rms_norm,
]
2931
3032post_lowering_pass_list = [
Original file line number Diff line number Diff line change 1+ import logging
2+
3+ import torch
4+ from torch_tensorrt .dynamo ._settings import CompilationSettings
5+
6+ logger = logging .getLogger (__name__ )
7+
8+
def replace_fused_rms_norm(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Replace ``aten._fused_rms_norm`` nodes with ``aten.rms_norm``.

    The fused variant has no converter, so it is rewritten to the standard
    op during pre-lowering.

    Args:
        gm: Graph module to rewrite in place.
        settings: Compilation settings (unused; kept for the pass signature).

    Returns:
        The mutated graph module.
    """
    count = 0
    # Snapshot the node list: erase_node() mutates the graph mid-iteration.
    for node in list(gm.graph.nodes):
        if node.target == torch.ops.aten._fused_rms_norm.default:
            # Insert the replacement at the old node's position so the graph
            # stays topologically ordered — a bare call_function() appends at
            # the current insert point (the end of the graph).
            with gm.graph.inserting_before(node):
                # Target the aten overload: torch.nn.functional.rms_norm is a
                # plain Python function and has no ``.default`` attribute.
                new_node = gm.graph.call_function(
                    torch.ops.aten.rms_norm.default,
                    args=node.args,
                    kwargs=node.kwargs,
                )
            # torch.fx.Graph has no replace_node_with_new_node(); rewire all
            # users explicitly, then erase the now-unused original node.
            # NOTE(review): assumes consumers read this node's primary output
            # directly; if ``_fused_rms_norm`` returns a tuple and its users
            # are getitem nodes, those need rewiring too — confirm the schema.
            node.replace_all_uses_with(new_node)
            gm.graph.erase_node(node)
            count += 1

    if count:
        # Validate and regenerate the module's forward() after graph surgery.
        gm.graph.lint()
        gm.recompile()

    logger.debug("Replaced %d fused rms norm nodes:\n%s", count, gm.graph)

    return gm
You can’t perform that action at this time.
0 commit comments