Commit 121d636

Recapture fake tensor metadata for inserted subgraphs (#4052)

Co-authored-by: cehongwang <wangcehong@gmail.com>

1 parent 98cf6f0 · commit 121d636

File tree

3 files changed: +212 -13 lines

py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+import contextlib
+import operator
+from collections import defaultdict
+from typing import Any, Optional
+
+import sympy
+import torch
+import torch.fx
+from torch._dispatch.python import enable_python_dispatcher
+from torch._inductor.fx_utils import get_fake_args_kwargs, get_node_storage, get_storage
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.symbolic_shapes import (
+    compute_unbacked_bindings,
+    rebind_unbacked,
+    statically_known_true,
+    sym_eq,
+)
+from torch.utils._ordered_set import OrderedSet
+
+
+# Adapted from torch._inductor.fx_utils.FakeTensorUpdater
+class FakeTensorUpdater:
+    """
+    The main idea here is that it's difficult to maintain accurate fake
+    tensors (our primary form of metadata) for each node in our graph as we
+    transform it.
+
+    The most reliable way to obtain this information is by rerunning
+    faketensor propagation. However, in general, faketensor propagation is
+    fairly expensive. So, instead we'd like to only rerun faketensor
+    propagation on nodes that have changed.
+
+    In order to detect which nodes have changed, we first hash each node
+    together with its target and argument lists (which are immutable in FX).
+
+    Then, whenever we call incremental_update, we check which FX nodes have a
+    new hash, and recompute the faketensor metadata for those nodes. Then, we
+    continue to recursively compute the faketensors for all users until the
+    fake tensors stop changing.
+    """
+
+    def __init__(self, graph: torch.fx.Graph) -> None:
+        self.processed_hashes = OrderedSet[Any]()
+        self.graph = graph
+
+        for node in self.graph.nodes:
+            self.processed_hashes.add(self.hash_node(node))
+
+    def hash_node(self, node: torch.fx.Node) -> tuple[torch.fx.Node, Any, Any, Any]:
+        return (node, node.target, id(node.args), id(node.kwargs))
+
+    def incremental_update(self, fake_mode: FakeTensorMode) -> None:
+        """Update FakeTensors on self.graph. We will try to do the minimum amount of work."""
+        existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
+        for node in self.graph.nodes:
+            existing_storages[get_node_storage(node)] += 1
+
+        def is_intlist_same(new: Any, old: Any) -> Any:
+            return statically_known_true(sym_eq(new, old))
+
+        def is_fake_tensor_same(new: Any, old: Any, *, node: torch.fx.Node) -> Any:
+            if type(new) is not type(old):
+                return False
+            if isinstance(new, (list, tuple)):
+                if len(new) != len(old):
+                    return False
+                return all(
+                    is_fake_tensor_same(new_i, old_i, node=node)
+                    for new_i, old_i in zip(new, old)
+                )
+            if new is None:
+                return old is None
+            if not isinstance(new, torch.Tensor):
+                assert isinstance(
+                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
+                ), f"Unknown type {type(new)} in {self.graph}"
+                return (
+                    new.node.shape_env._maybe_evaluate_static(
+                        sympy.Eq(new.node.expr, old.node.expr)
+                    )
+                    == sympy.true
+                )
+            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
+                return False
+            if new.layout == torch.strided and (
+                not is_intlist_same(new.stride(), old.stride())
+                or not statically_known_true(
+                    new.storage_offset() == old.storage_offset()
+                )
+            ):
+                return False
+
+            if new.device != old.device:
+                return False
+
+            if get_storage(new) == get_storage(old):
+                return True
+
+            def any_user_may_alias(node: torch.fx.Node) -> bool:
+                if not isinstance(node.meta["val"], torch.Tensor):
+                    # analysis too complicated on lists, can support in the future
+                    return True
+                for user in node.users:
+                    if not (
+                        isinstance(
+                            user.target,
+                            (torch._ops.OpOverload, torch._ops.HigherOrderOperator),
+                        )
+                    ):
+                        return True
+                    if isinstance(user.target, torch._ops.HigherOrderOperator):
+                        # HOPs that survive until inductor are all non-aliasing HOPs.
+                        # We will likely never support HOPs that are aliasing.
+                        continue
+                    # Strategy: do a FakeTensor prop and see if the storage aliases.
+                    # If Inductor ever gets tighter invariants on OpOverloads
+                    # (that is, we ban things like torch.ops.aten.reshape calls in the
+                    # graph), then this could just be a fast schema lookup.
+                    is_valid, args, kwargs = get_fake_args_kwargs(user)
+                    if not is_valid:
+                        return True
+                    with (
+                        fake_mode,
+                        enable_python_dispatcher(),
+                        contextlib.ExitStack() as stack,
+                    ):
+                        # Ignore unbacked symbols (if they exist): we're making
+                        # this FakeTensor and then throwing it away.
+                        if fake_mode.shape_env is not None:
+                            stack.enter_context(
+                                fake_mode.shape_env.ignore_fresh_unbacked_symbols()
+                            )
+                        new_fake_tensor = user.target(*args, **kwargs)
+                    if not isinstance(new_fake_tensor, torch.Tensor):
+                        # analysis too complicated on lists, can support in the future
+                        return True
+                    if get_storage(new_fake_tensor) == get_storage(node.meta["val"]):
+                        return True
+                return False
+
+            # This is the case where it returns a completely fresh storage that's used nowhere else.
+            # If the FakeTensor's storage is fresh and none of the node's users can alias it, then
+            # we don't need to update this node.
+            if (
+                existing_storages[get_storage(old)] == 1
+                and get_storage(new) not in existing_storages
+                and not any_user_may_alias(node)
+            ):
+                return True
+
+            return False

+        def should_process_node(node: torch.fx.Node) -> bool:
+            # node.target for nodes returning true from this function is
+            # called under fake mode, which does not work for inductor
+            # lowerings. We check whether node.target is an aten operator
+            # or operator.getitem, which is used when returning multiple
+            # tensors from an op.
+            return node.op == "call_function" and (
+                isinstance(node.target, torch._ops.OpOverload)
+                or node.target is operator.getitem
+                or node.target
+                is torch._inductor.fx_passes.reinplace._generalized_scatter
+            )
+
+        to_process = OrderedSet[int]()
+        for node in self.graph.nodes:
+            # NB: Be very careful about skipping nodes (via continues) here
+            # and ask for a careful review when changing this code. The
+            # consequence of incorrect FakeTensor metadata is difficult-to-debug
+            # silent incorrectness.
+            if (
+                self.hash_node(node) in self.processed_hashes
+                and id(node) not in to_process
+            ):
+                continue
+
+            if not should_process_node(node):
+                continue
+
+            is_valid, args, kwargs = get_fake_args_kwargs(node)
+            if not is_valid:
+                continue
+            with fake_mode, enable_python_dispatcher():
+                new_fake_tensor = node.target(*args, **kwargs)
+
+            if "val" in node.meta and is_fake_tensor_same(
+                new_fake_tensor, node.meta["val"], node=node
+            ):
+                continue
+
+            rebind_unbacked(fake_mode.shape_env, node, new_fake_tensor)
+
+            node.meta["val"] = new_fake_tensor
+            if (shape_env := fake_mode.shape_env) and (
+                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
+            ):
+                # Refresh the bindings to the new symbols
+                node.meta["unbacked_bindings"] = symbol_to_path
+
+            existing_storages[get_node_storage(node)] += 1
+
+            to_process.update([id(user) for user in node.users])
+
+            self.processed_hashes.add(self.hash_node(node))
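
To make the intended call pattern concrete, a minimal usage sketch (not part of this commit) follows. It assumes the module path given above; FakeTensorProp seeds the initial meta["val"] entries, and the relu-to-sigmoid retarget is a hypothetical stand-in for a lowering pass that changes a node's hash:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.relu.default(x)


gm = torch.fx.symbolic_trace(M())
fake_mode = FakeTensorMode()
# Seed meta["val"] on every node by running one full fake propagation.
FakeTensorProp(gm, fake_mode).propagate(torch.empty(2, 3))

# Snapshot node hashes before transforming the graph.
updater = FakeTensorUpdater(gm.graph)

# A stand-in lowering pass: retargeting the node changes its hash, so
# incremental_update re-propagates only this node and its users.
for node in gm.graph.nodes:
    if node.target is torch.ops.aten.relu.default:
        node.target = torch.ops.aten.sigmoid.default

updater.incremental_update(fake_mode)

Because only the retargeted node's hash changes, incremental_update recomputes metadata for that node and its users rather than re-propagating the whole graph.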

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 6 additions & 0 deletions
@@ -3,8 +3,11 @@
 from typing import Any, Callable, Optional, Sequence, Union

 import torch
+from torch._inductor.fx_utils import get_node_storage
+
 from torch_tensorrt._utils import is_tegra_platform
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     trace_intermediate_node_outputs,
 )
@@ -131,7 +134,10 @@ def post_lowering(
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
     )
+    fake_tensor_updater = FakeTensorUpdater(gm.graph)
     gm = ATEN_POST_LOWERING_PASSES(gm, settings)
+    if (fake_mode := torch._export.utils._detect_fake_mode_from_gm(gm)) is not None:
+        fake_tensor_updater.incremental_update(fake_mode)

     return gm

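For reference, _detect_fake_mode_from_gm is a private torch._export utility that recovers the FakeTensorMode a graph's metadata was produced under, or None if there is none. A hedged sketch of an equivalent guard using torch._guards.detect_fake_mode, assuming the graph already carries FakeTensor meta["val"] entries:

# Alternative to the private helper above: collect the FakeTensors the graph
# already carries in meta["val"] and let detect_fake_mode find their mode.
from torch._guards import detect_fake_mode

vals = [n.meta["val"] for n in gm.graph.nodes if "val" in n.meta]
if (fake_mode := detect_fake_mode(vals)) is not None:
    fake_tensor_updater.incremental_update(fake_mode)
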
py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py

Lines changed: 0 additions & 13 deletions
@@ -37,7 +37,6 @@ def process_fused_rms_norm_node(
     # Calculate dimensions to normalize over (similar to layer_norm)
     # normalized_shape specifies the last N dimensions
     x_dim = len(node.meta["val"][0].shape)
-    x_fake = node.meta["val"][0]
     dims_to_reduce = []
     for i in range(len(shape)):
         dims_to_reduce.append(x_dim - i - 1)
@@ -48,44 +47,32 @@
         torch.ops.aten.mul.Tensor,
         args=(x, x),
     )
-    x_squared_fake = x_fake * x_fake
-    x_squared.meta["val"] = x_squared_fake

     x_squared_sum = gm.graph.call_function(
         torch.ops.aten.mean.dim,
         args=(x_squared, dims_to_reduce, True),
     )
-    x_squared_sum_fake = x_squared_fake.mean(dims_to_reduce, keepdim=True)
-    x_squared_sum.meta["val"] = x_squared_sum_fake

     x_squared_sum_eps = gm.graph.call_function(
         torch.ops.aten.add.Tensor,
         args=(x_squared_sum, eps),
     )
-    x_squared_sum_eps_fake = x_squared_sum_fake + eps
-    x_squared_sum_eps.meta["val"] = x_squared_sum_eps_fake

     x_squared_sum_eps_rsqrt = gm.graph.call_function(
         torch.ops.aten.rsqrt.default,
         args=(x_squared_sum_eps,),
     )
-    x_squared_sum_eps_rsqrt_fake = x_squared_sum_eps_fake.rsqrt()
-    x_squared_sum_eps_rsqrt.meta["val"] = x_squared_sum_eps_rsqrt_fake

     x_normalized = gm.graph.call_function(
         torch.ops.aten.mul.Tensor,
         args=(x, x_squared_sum_eps_rsqrt),
     )
-    x_normalized_fake = x_fake * x_squared_sum_eps_rsqrt_fake
-    x_normalized.meta["val"] = x_normalized_fake

     if weight is not None:
         x_normalized = gm.graph.call_function(
             torch.ops.aten.mul.Tensor,
             args=(x_normalized, weight),
         )
-        x_normalized_fake = x_normalized_fake * weight.meta["val"]
-        x_normalized.meta["val"] = x_normalized_fake

     for i, user in enumerate(list(node.users)):
         if user.op == "call_function" and user.target == operator.getitem:

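The deleted lines above hand-computed and assigned a fake tensor for every node this pass inserts. With FakeTensorUpdater now wired into post_lowering, that bookkeeping is redundant: for each inserted or changed node, the updater recomputes the metadata by re-running the node's target under fake mode, roughly as follows (a simplification of incremental_update above):

# Simplified from FakeTensorUpdater.incremental_update: a changed node's fake
# metadata is recaptured by re-running its target on its args' fake tensors.
is_valid, args, kwargs = get_fake_args_kwargs(node)
if is_valid:
    with fake_mode, enable_python_dispatcher():
        node.meta["val"] = node.target(*args, **kwargs)
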