
Commit b2d063b

Revised the lowering pass according to Bo's suggestion

1 parent e760c3c commit b2d063b

File tree: 1 file changed (+169 -23 lines)

examples/dynamo/llama2_flashinfer_rmsnorm.py
Lines changed: 169 additions & 23 deletions

@@ -15,11 +15,12 @@
 This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.
 """
 
-from typing import Callable, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 import flashinfer
 import torch
 import torch_tensorrt
+from torch._subclasses import FakeTensor
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
     _aten_lowering_pass,
@@ -51,6 +52,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor
 def replace_rmsnorm(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
+    print("before2\n")
+    print(gm.graph)
     for node in gm.graph.nodes:
         if (
             node.target == torch.ops.aten._to_copy.default
@@ -90,13 +93,60 @@ def replace_rmsnorm(
             weight_mul_node = list(copy_node.users)[0]
 
             weight = weight_mul_node.args[0]
+            hidden_states_node = node.args[0]
 
-            original_meta = weight_mul_node.meta.get(
+            original_meta = hidden_states_node.meta.get(
                 "tensor_meta", {}
             )
             memory_format = original_meta.memory_format
+            from torch.fx.experimental.symbolic_shapes import (
+                ShapeEnv,
+            )
+
+            shape_env = ShapeEnv()
 
             with gm.graph.inserting_after(weight_mul_node):
+                input_meta = node.args[0].meta["val"]
+                batch_size = input_meta.shape[0]
+                seq_len = input_meta.shape[1]
+                head_dim = input_meta.shape[2]
+
+                # Create symbolic ints for batch_size
+                if isinstance(batch_size, int):
+                    batch_size_unbacked_symint = (
+                        shape_env.create_unbacked_symint()
+                    )
+                    torch._check(
+                        batch_size_unbacked_symint >= batch_size
+                    )
+                    torch._check(
+                        batch_size_unbacked_symint <= batch_size
+                    )
+                elif isinstance(batch_size, torch.SymInt):
+                    pass
+                else:
+                    raise ValueError(
+                        "Batch size must be a sym int"
+                    )
+
+                # Create symbolic ints for head_dim
+                if isinstance(head_dim, int):
+                    head_dim_unbacked_symint = (
+                        shape_env.create_unbacked_symint()
+                    )
+                    torch._check(
+                        head_dim_unbacked_symint >= head_dim
+                    )
+                    torch._check(
+                        head_dim_unbacked_symint <= head_dim
+                    )
+                elif isinstance(head_dim, torch.SymInt):
+                    pass
+                else:
+                    raise ValueError(
+                        "head_dim must be a sym int"
+                    )
+
                 b = gm.graph.create_node(
                     op="call_function",
                     target=torch.ops.aten.sym_size.int,
@@ -111,19 +161,24 @@ def replace_rmsnorm(
                     is_quantized=False,
                     qparams={},
                 )
+
+                batch_size = node.args[0].meta["val"].shape[0]
+                b.meta["val"] = batch_size_unbacked_symint
+
                 s = gm.graph.create_node(
                     op="call_function",
                     target=torch.ops.aten.sym_size.int,
                     args=(node.args[0], 1),
                 )
                 s.meta.update(b.meta)
-
+                s.meta["val"] = seq_len
                 d = gm.graph.create_node(
                     op="call_function",
                     target=torch.ops.aten.sym_size.int,
                     args=(node.args[0], 2),
                 )
                 d.meta.update(b.meta)
+                d.meta["val"] = head_dim_unbacked_symint
 
                 with gm.graph.inserting_after(b):
                     new_first_dim = gm.graph.create_node(
@@ -150,11 +205,11 @@ def replace_rmsnorm(
                                [b_val * s_val, d_val]
                            ),
                            dtype=original_meta.dtype,
-                            requires_grad=True,
                            stride=None,
                            memory_format=memory_format,
                            is_quantized=False,
                            qparams={},
+                            requires_grad=False,
                        )
                    )

@@ -183,11 +238,22 @@ def replace_rmsnorm(
                         [b, s, d],
                     ),
                 )
+                reshapback_node.meta["tensor_meta"] = (
+                    TensorMetadata(
+                        shape=torch.Size([b_val, s_val, d_val]),
+                        dtype=original_meta.dtype,
+                        stride=None,
+                        memory_format=memory_format,
+                        is_quantized=False,
+                        qparams={},
+                        requires_grad=False,
+                    )
+                )
 
+                # reshapback_node.meta.update(weight_mul_node.meta)
                 weight_mul_node.replace_all_uses_with(
                     reshapback_node
                 )
-                reshapback_node.meta.update(weight_mul_node.meta)
 
                 modified_graph = True

@@ -207,6 +273,43 @@ def replace_rmsnorm(
     return gm
 
 
+@_aten_lowering_pass
+def set_copy_node_meta_data(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten._to_copy.default and (
+            "tensor_meta" not in node.meta
+        ):
+            input_node = node.args[0]
+
+            # Check if input has metadata
+            if "tensor_meta" in input_node.meta:
+                # Copy input metadata and update dtype to float32
+                output_meta = input_node.meta["tensor_meta"]
+                # output_meta.dtype = node.kwargs.get("dtype")
+
+                # # Assign to the _to_copy node
+                # node.meta["tensor_meta"] = output_meta
+                node.meta["tensor_meta"] = TensorMetadata(
+                    shape=output_meta.shape,
+                    dtype=node.kwargs.get("dtype"),
+                    requires_grad=True,
+                    stride=None,
+                    memory_format=input_node.meta["tensor_meta"].memory_format,
+                    is_quantized=False,
+                    qparams={},
+                )
+
+            else:
+                # Handle missing metadata (optional warning/logging)
+                print(f"Warning: Input node {input_node} has no tensor_meta")
+
+    gm = clean_up_graph_after_modifications(gm)
+
+    return gm
+
+
 # 1. Create a custom config with 1 layer
 config = LlamaConfig(
     vocab_size=32000,
@@ -222,12 +325,14 @@ def replace_rmsnorm(
 with torch.no_grad():
     model = LlamaForCausalLM(config).eval().half()
 
+MAX_TOKENS = 64
+seq_len = torch.export.Dim("seq_len", min=2, max=MAX_TOKENS)
 # 3. Export with static shapes
 input_ids = torch.randint(0, 32000, (1, 64))  # Static [batch=1, seq=64]
 exported = torch.export.export(
     model,
     (input_ids,),
-    dynamic_shapes=None,  # Fully static
+    dynamic_shapes=({1: seq_len},),
 )
 
 # Test forward pass
@@ -238,20 +343,61 @@ def replace_rmsnorm(
 # Export validation
 
 DEVICE = torch.device("cuda:0")
-
-with torch_tensorrt.logging.errors():
-    trt_model = torch_tensorrt.dynamo.compile(
-        exported,
-        inputs=[input_ids],
-        enabled_precisions={torch.float32, torch.float16},
-        truncate_double=True,
-        device=DEVICE,
-        disable_tf32=True,
-        use_explicit_typing=False,
-        use_fp32_acc=True,
-    )
-
-    input_ids = input_ids.to(DEVICE)
-
-    res = trt_model.forward(input_ids)
-    print(res)
+stream = torch.cuda.Stream()
+with torch.cuda.stream(stream):
+    with torch_tensorrt.dynamo.Debugger(
+        log_level="info",
+        # profile_format="trex",
+        # save_engine_profile=True,
+        capture_fx_graph_before=["remove_detach"],
+        capture_fx_graph_after=["replace_rmsnorm"],
+        logging_dir="/home/profile/logging/torchtrt",
+        engine_builder_monitor=False,
+    ):
+        trt_model = torch_tensorrt.dynamo.compile(
+            exported,
+            inputs=[input_ids],
+            enabled_precisions={torch.float32, torch.float16},
+            truncate_double=True,
+            device=DEVICE,
+            disable_tf32=True,
+            use_explicit_typing=False,
+            use_fp32_acc=True,
+            use_python_runtime=True,
+        )
+
+    input_ids = input_ids.to(DEVICE)
+
+    res = trt_model.forward(input_ids)
+
+# Benchmark TensorRT models
+
+import time
+
+def benchmark_model(model, input_ids, label, n_runs=100):
+    torch.cuda.synchronize()
+    start = time.time()
+    for _ in range(n_runs):
+        with torch.no_grad():
+            out = model(input_ids)
+    torch.cuda.synchronize()
+    end = time.time()
+    print(f"{label}: {n_runs} runs, total {(end - start):.4f} s")
+    return out
+
+# Warmup
+with torch.no_grad():
+    _ = trt_model(input_ids)
+
+# Benchmark
+trt_out = benchmark_model(trt_model, input_ids, "TensorRT model")
+
+# Compare outputs
+
+pytorch_logits = output.logits
+trt_logits = trt_out.logits
+
+pytorch_logits = pytorch_logits.to(DEVICE)
+trt_logits = trt_logits.to(DEVICE)
+print("Max abs diff:", (pytorch_logits - trt_logits).abs().max().item())
+print("Mean abs diff:", (pytorch_logits - trt_logits).abs().mean().item())

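The export-side change replaces dynamic_shapes=None with a torch.export.Dim spec so the sequence dimension stays symbolic through compilation. A hedged usage sketch on a toy module (the Toy class is illustrative, not from the commit), showing that the resulting ExportedProgram re-runs at other sequence lengths:

    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(32000, 16)

        def forward(self, input_ids):
            return self.emb(input_ids).sum(-1)

    MAX_TOKENS = 64
    seq_len = torch.export.Dim("seq_len", min=2, max=MAX_TOKENS)

    # Dimension 1 of the single positional input is marked dynamic,
    # exactly as in the commit's dynamic_shapes=({1: seq_len},).
    ep = torch.export.export(
        Toy(),
        (torch.randint(0, 32000, (1, 64)),),
        dynamic_shapes=({1: seq_len},),
    )
    print(ep.module()(torch.randint(0, 32000, (1, 32))).shape)  # torch.Size([1, 32])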