
Commit 691360d

squash the commit
1 parent b7ae84f commit 691360d

File tree: 5 files changed, +660 -0 lines changed

py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
import contextlib
import operator
from collections import defaultdict
from typing import Any, Optional

import sympy
import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch._inductor.fx_utils import get_fake_args_kwargs, get_node_storage, get_storage
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils._ordered_set import OrderedSet


# Adapted from torch._inductor.fx_utils.FakeTensorUpdater
class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    In order to detect which nodes have changed, we first hash its node,
    target, and argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash, and recompute the faketensor metadata for that node. Then, we
    continue to recursively compute the faketensors for all users until the
    fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph) -> None:
        self.processed_hashes = OrderedSet[Any]()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node) -> tuple[torch.fx.Node, Any, Any, Any]:
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self, fake_mode: FakeTensorMode) -> None:
        """Update FakeTensors on self.graph. We will try to do the minimum amount of work."""
        existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new: Any, old: Any) -> Any:
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new: Any, old: Any, *, node: torch.fx.Node) -> Any:
            if type(new) is not type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i, node=node)
                    for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False

            if new.device != old.device:
                return False

            if get_storage(new) == get_storage(old):
                return True

            def any_user_may_alias(node: torch.fx.Node) -> bool:
                if not isinstance(node.meta["val"], torch.Tensor):
                    # analysis too complicated on lists, can support in the future
                    return True
                for user in node.users:
                    if not (
                        isinstance(
                            user.target,
                            (torch._ops.OpOverload, torch._ops.HigherOrderOperator),
                        )
                    ):
                        return True
                    if isinstance(user.target, torch._ops.HigherOrderOperator):
                        # HOPs that survive until inductor are all non-aliasing HOPs.
                        # We will likely never support HOPs that are aliasing.
                        continue
                    # Strategy: do a FakeTensor prop, see if the storage aliases.
                    # If Inductor ever gets tighter invariants on OpOverloads
                    # (that is, we ban things like torch.ops.aten.reshape calls in the graph),
                    # Then this could just be a fast schema lookup.
                    is_valid, args, kwargs = get_fake_args_kwargs(user)
                    if not is_valid:
                        return True
                    with (
                        fake_mode,
                        enable_python_dispatcher(),
                        contextlib.ExitStack() as stack,
                    ):
                        # Ignore unbacked symbols (if they exist): we're making
                        # this FakeTensor and then throwing it away.
                        if fake_mode.shape_env is not None:
                            stack.enter_context(
                                fake_mode.shape_env.ignore_fresh_unbacked_symbols()
                            )
                        new_fake_tensor = user.target(*args, **kwargs)
                    if not isinstance(new_fake_tensor, torch.Tensor):
                        # analysis too complicated on lists, can support in the future
                        return True
                    if get_storage(new_fake_tensor) == get_storage(node.meta["val"]):
                        return True
                return False

            # This is the case where it returns a completely fresh storage that's used nowhere else.
            # If the FakeTensor's storage is fresh and none of the node's users can alias it, then
            # we don't need to update this node.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
                and not any_user_may_alias(node)
            ):
                return True

            return False

        def should_process_node(node: torch.fx.Node) -> bool:
            # node.target for nodes returning True from this function
            # is called under fake mode and does not work for inductor
            # lowerings. We check whether node.target is an aten operator
            # or operator.getitem, which is used when returning multiple
            # tensors from an op.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target is operator.getitem
                or node.target
                is torch._inductor.fx_passes.reinplace._generalized_scatter
            )

        to_process = OrderedSet[int]()
        for node in self.graph.nodes:
            # NB: Be very careful about skipping nodes (via continues) here
            # and ask for a careful review when changing this code. The
            # consequence for incorrect FakeTensor metadata is difficult-to-debug
            # silent incorrectness.
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            with fake_mode, enable_python_dispatcher():
                new_fake_tensor = node.target(*args, **kwargs)

            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"], node=node
            ):
                continue

            rebind_unbacked(fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1

            to_process.update([id(user) for user in node.users])

            self.processed_hashes.add(self.hash_node(node))
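
Usage note (not part of the commit): FakeTensorUpdater is only driven indirectly by the lowering pipeline in the next file, so here is a minimal, self-contained sketch of how it is meant to be used. The traced function and shapes are made up; make_fx is ordinary fake-tensor FX tracing, and _detect_fake_mode_from_gm is the same private helper the commit's post_lowering change calls.

# Illustrative sketch only: snapshot node hashes before graph passes run,
# then refresh FakeTensor metadata for just the nodes whose hashes changed.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._export.utils import _detect_fake_mode_from_gm
from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater

def f(x):
    return torch.relu(x) + 1

# Tracing with tracing_mode="fake" stores FakeTensors in node.meta["val"].
gm = make_fx(f, tracing_mode="fake")(torch.randn(4))

updater = FakeTensorUpdater(gm.graph)  # hash every node up front

# ... graph transformations would mutate gm.graph here, possibly invalidating meta ...

# Recover the FakeTensorMode the same way post_lowering does, then update.
fake_mode = _detect_fake_mode_from_gm(gm)
if fake_mode is not None:
    updater.incremental_update(fake_mode)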

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 from torch_tensorrt._utils import is_tegra_platform
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     trace_intermediate_node_outputs,
 )
@@ -18,6 +19,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
 from .repair_input_as_output import repair_input_as_output
+from .replace_fused_rms_norm import replace_fused_rms_norm
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 from .rule_based_autocast import rule_based_autocast

@@ -28,6 +30,7 @@
 ]

 post_lowering_pass_list = [
+    replace_fused_rms_norm,
     remove_input_alias_fixing_clones,
     constant_fold,
     repair_input_as_output,
@@ -129,7 +132,10 @@ def post_lowering(
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
     )
+    fake_tensor_updater = FakeTensorUpdater(gm.graph)
     gm = ATEN_POST_LOWERING_PASSES(gm, settings)
+    if (fake_mode := torch._export.utils._detect_fake_mode_from_gm(gm)) is not None:
+        fake_tensor_updater.incremental_update(fake_mode)

     return gm

py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
import logging
import operator

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def replace_fused_rms_norm(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Replace fused rms norm ops in the graph"""
    count = 0
    for node in gm.graph.nodes:
        if node.target == torch.ops.aten._fused_rms_norm.default:
            x_normalized, rsqrt = process_fused_rms_norm_node(node, gm)
            count += 1

    logger.debug(f"Replaced {count} fused rms norm nodes:\n{gm.graph}")

    gm = clean_up_graph_after_modifications(gm)

    return gm


def process_fused_rms_norm_node(
    node: torch.fx.Node, gm: torch.fx.GraphModule
) -> tuple[torch.fx.Node, torch.fx.Node]:
    x, shape, weight, eps = node.args[0], node.args[1], node.args[2], node.args[3]
    if eps is None:
        eps = 1e-5
    # Calculate dimensions to normalize over (similar to layer_norm)
    # normalized_shape specifies the last N dimensions
    x_dim = len(node.meta["val"][0].shape)
    dims_to_reduce = []
    for i in range(len(shape)):
        dims_to_reduce.append(x_dim - i - 1)

    with gm.graph.inserting_before(node):
        # Replace fused rms norm with standard rms norm
        x_squared = gm.graph.call_function(
            torch.ops.aten.mul.Tensor,
            args=(x, x),
        )

        x_squared_sum = gm.graph.call_function(
            torch.ops.aten.mean.dim,
            args=(x_squared, dims_to_reduce, True),
        )

        x_squared_sum_eps = gm.graph.call_function(
            torch.ops.aten.add.Tensor,
            args=(x_squared_sum, eps),
        )

        x_squared_sum_eps_rsqrt = gm.graph.call_function(
            torch.ops.aten.rsqrt.default,
            args=(x_squared_sum_eps,),
        )

        x_normalized = gm.graph.call_function(
            torch.ops.aten.mul.Tensor,
            args=(x, x_squared_sum_eps_rsqrt),
        )

        if weight is not None:
            x_normalized = gm.graph.call_function(
                torch.ops.aten.mul.Tensor,
                args=(x_normalized, weight),
            )

    for i, user in enumerate(list(node.users)):
        if user.op == "call_function" and user.target == operator.getitem:
            if i == 0:
                # If the getitem is extracting the first element (the output tensor)
                user.replace_all_uses_with(x_normalized)
            else:
                user.replace_all_uses_with(x_squared_sum_eps_rsqrt)

            logger.debug(
                f"Replaced {i}-th user of fused_rms_norm node [{user}] with lowered rms_norm output [{x_normalized if i == 0 else x_squared_sum_eps_rsqrt}]"
            )
            gm.graph.erase_node(user)

    gm.graph.erase_node(node)

    return x_normalized, x_squared_sum_eps_rsqrt
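
For reference (not part of the commit), the subgraph built above computes the standard RMSNorm decomposition y = x * rsqrt(mean(x*x, dims, keepdim=True) + eps) * weight, and it returns the rsqrt term as the second output, mirroring aten._fused_rms_norm. A small eager-mode sketch of the same arithmetic follows; the comparison against torch.nn.functional.rms_norm assumes a PyTorch build (2.4+) that exposes that function.

# Illustrative sketch only: the eager-mode computation the rewritten subgraph performs.
import torch

def decomposed_rms_norm(x, normalized_shape, weight=None, eps=1e-5):
    # Reduce over the last len(normalized_shape) dimensions, mirroring the pass.
    dims_to_reduce = [x.dim() - i - 1 for i in range(len(normalized_shape))]
    rsqrt = torch.rsqrt(torch.mean(x * x, dims_to_reduce, keepdim=True) + eps)
    out = x * rsqrt
    if weight is not None:
        out = out * weight
    return out, rsqrt

x = torch.randn(2, 4, 8)
w = torch.randn(8)
out, _ = decomposed_rms_norm(x, normalized_shape=[8], weight=w)

# On builds that expose torch.nn.functional.rms_norm, the result should match.
if hasattr(torch.nn.functional, "rms_norm"):
    ref = torch.nn.functional.rms_norm(x, [8], weight=w, eps=1e-5)
    torch.testing.assert_close(out, ref)
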
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import pytest
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_ATEN_CONVERTERS,
    DYNAMO_CONVERTERS,
)


@pytest.fixture(autouse=True)
def reset_torch_tensorrt_state():
    """
    Ensure test isolation by restoring converter registry state and clearing caches.
    This prevents earlier tests from mutating global state (e.g., disallowed targets)
    which can cause different partitioning outcomes when running multiple tests.
    """
    # Snapshot current global state
    original_registry = {k: list(v) for k, v in DYNAMO_ATEN_CONVERTERS.items()}
    original_disallowed = set(getattr(DYNAMO_CONVERTERS, "disallowed_targets", set()))
    original_settings = getattr(DYNAMO_CONVERTERS, "compilation_settings", None)

    try:
        yield
    finally:
        # Restore converter registry
        DYNAMO_ATEN_CONVERTERS.clear()
        DYNAMO_ATEN_CONVERTERS.update(
            {k: list(v) for k, v in original_registry.items()}
        )

        # Restore disallowed targets and compilation settings
        try:
            DYNAMO_CONVERTERS.set_disallowed_targets(original_disallowed)
        except Exception:
            pass
        if original_settings is not None:
            try:
                DYNAMO_CONVERTERS.set_compilation_settings(original_settings)
            except Exception:
                pass

        # Clear caches again to avoid stale state carrying forward
        try:
            trace_atomic_graph.cache_clear()
        except Exception:
            pass
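
Since the fixture above is autouse, every test collected under the same conftest scope runs between the snapshot and the restore without requesting the fixture explicitly. A hypothetical test illustrating the isolation it provides (not part of the commit; the test name and the mutated target are made up):

# Hypothetical test module under the same conftest scope. The autouse fixture
# snapshots the converter registry before this test and restores it afterwards,
# so the disallowed-target mutation below cannot leak into later tests.
import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

def test_partitioning_with_disallowed_target():
    # Mutates global registry state; safe only because the fixture restores it.
    DYNAMO_CONVERTERS.set_disallowed_targets({torch.ops.aten.relu.default})
    # ... compile a module here and assert on the resulting partitioning ...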
