1515This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.
1616"""
1717
18- from typing import Callable , Optional , Sequence , Union
18+ from typing import Any , Callable , Optional , Sequence , Union
1919
2020import flashinfer
2121import torch
2222import torch_tensorrt
23+ from torch ._subclasses import FakeTensor
2324from torch .fx .passes .shape_prop import TensorMetadata
2425from torch_tensorrt .dynamo .lowering .passes ._aten_lowering_pass import (
2526 _aten_lowering_pass ,
@@ -51,6 +52,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tenso
5152def replace_rmsnorm (
5253 gm : torch .fx .GraphModule , sample_inputs : Sequence [torch .Tensor ]
5354) -> torch .fx .GraphModule :
55+ print ("before2\n " )
56+ print (gm .graph )
5457 for node in gm .graph .nodes :
5558 if (
5659 node .target == torch .ops .aten ._to_copy .default
@@ -90,13 +93,60 @@ def replace_rmsnorm(
9093 weight_mul_node = list (copy_node .users )[0 ]
9194
9295 weight = weight_mul_node .args [0 ]
96+ hidden_states_node = node .args [0 ]
9397
94- original_meta = weight_mul_node .meta .get (
98+ original_meta = hidden_states_node .meta .get (
9599 "tensor_meta" , {}
96100 )
97101 memory_format = original_meta .memory_format
102+ from torch .fx .experimental .symbolic_shapes import (
103+ ShapeEnv ,
104+ )
105+
106+ shape_env = ShapeEnv ()
98107
99108 with gm .graph .inserting_after (weight_mul_node ):
109+ input_meta = node .args [0 ].meta ["val" ]
110+ batch_size = input_meta .shape [0 ]
111+ seq_len = input_meta .shape [1 ]
112+ head_dim = input_meta .shape [2 ]
113+
114+ # Create symbolic ints for batch_size
115+ if isinstance (batch_size , int ):
116+ batch_size_unbacked_symint = (
117+ shape_env .create_unbacked_symint ()
118+ )
119+ torch ._check (
120+ batch_size_unbacked_symint >= batch_size
121+ )
122+ torch ._check (
123+ batch_size_unbacked_symint <= batch_size
124+ )
125+ elif isinstance (batch_size , torch .SymInt ):
126+ pass
127+ else :
128+ raise ValueError (
129+ "Batch size must be a sym int"
130+ )
131+
132+ # Create symbolic ints for head_dim
133+ if isinstance (head_dim , int ):
134+ head_dim_unbacked_symint = (
135+ shape_env .create_unbacked_symint ()
136+ )
137+ torch ._check (
138+ head_dim_unbacked_symint >= head_dim
139+ )
140+ torch ._check (
141+ head_dim_unbacked_symint <= head_dim
142+ )
143+ elif isinstance (head_dim , torch .SymInt ):
144+ pass
145+ else :
146+ raise ValueError (
147+ "head_dim must be a sym int"
148+ )
149+
100150 b = gm .graph .create_node (
101151 op = "call_function" ,
102152 target = torch .ops .aten .sym_size .int ,
@@ -111,19 +161,24 @@ def replace_rmsnorm(
111161 is_quantized = False ,
112162 qparams = {},
113163 )
164+
165+ batch_size = node .args [0 ].meta ["val" ].shape [0 ]
166+ b .meta ["val" ] = batch_size_unbacked_symint
167+
114168 s = gm .graph .create_node (
115169 op = "call_function" ,
116170 target = torch .ops .aten .sym_size .int ,
117171 args = (node .args [0 ], 1 ),
118172 )
119173 s .meta .update (b .meta )
120-
174+ s . meta [ "val" ] = seq_len
121175 d = gm .graph .create_node (
122176 op = "call_function" ,
123177 target = torch .ops .aten .sym_size .int ,
124178 args = (node .args [0 ], 2 ),
125179 )
126180 d .meta .update (b .meta )
181+ d .meta ["val" ] = head_dim_unbacked_symint
127182
128183 with gm .graph .inserting_after (b ):
129184 new_first_dim = gm .graph .create_node (
@@ -150,11 +205,11 @@ def replace_rmsnorm(
150205 [b_val * s_val , d_val ]
151206 ),
152207 dtype = original_meta .dtype ,
153- requires_grad = True ,
154208 stride = None ,
155209 memory_format = memory_format ,
156210 is_quantized = False ,
157211 qparams = {},
212+ requires_grad = False ,
158213 )
159214 )
160215
@@ -183,11 +238,22 @@ def replace_rmsnorm(
183238 [b , s , d ],
184239 ),
185240 )
241+ reshapback_node .meta ["tensor_meta" ] = (
242+ TensorMetadata (
243+ shape = torch .Size ([b_val , s_val , d_val ]),
244+ dtype = original_meta .dtype ,
245+ stride = None ,
246+ memory_format = memory_format ,
247+ is_quantized = False ,
248+ qparams = {},
249+ requires_grad = False ,
250+ )
251+ )
186252
253+ # reshapback_node.meta.update(weight_mul_node.meta)
187254 weight_mul_node .replace_all_uses_with (
188255 reshapback_node
189256 )
190- reshapback_node .meta .update (weight_mul_node .meta )
191257
192258 modified_graph = True
193259
@@ -207,6 +273,43 @@ def replace_rmsnorm(
207273 return gm
208274
209275
276+ @_aten_lowering_pass
277+ def set_copy_node_meta_data (
278+ gm : torch .fx .GraphModule , sample_inputs : Sequence [torch .Tensor ]
279+ ) -> torch .fx .GraphModule :
280+ for node in gm .graph .nodes :
281+ if node .target == torch .ops .aten ._to_copy .default and (
282+ "tensor_meta" not in node .meta
283+ ):
284+ input_node = node .args [0 ]
285+
286+ # Check if input has metadata
287+ if "tensor_meta" in input_node .meta :
288+ # Copy input metadata and update dtype to float32
289+ output_meta = input_node .meta ["tensor_meta" ]
290+ # output_meta.dtype = node.kwargs.get("dtype")
291+
292+ # # Assign to the _to_copy node
293+ # node.meta["tensor_meta"] = output_meta
294+ node .meta ["tensor_meta" ] = TensorMetadata (
295+ shape = output_meta .shape ,
296+ dtype = node .kwargs .get ("dtype" ),
297+ requires_grad = True ,
298+ stride = None ,
299+ memory_format = input_node .meta ["tensor_meta" ].memory_format ,
300+ is_quantized = False ,
301+ qparams = {},
302+ )
303+
304+ else :
305+ # Handle missing metadata (optional warning/logging)
306+ print (f"Warning: Input node { input_node } has no tensor_meta" )
307+
308+ gm = clean_up_graph_after_modifications (gm )
309+
310+ return gm
311+
312+
210313# 1. Create a custom config with 1 layer
211314config = LlamaConfig (
212315 vocab_size = 32000 ,
@@ -222,12 +325,14 @@ def replace_rmsnorm(
222325with torch .no_grad ():
223326 model = LlamaForCausalLM (config ).eval ().half ()
224327
328+ MAX_TOKENS = 64
329+ seq_len = torch .export .Dim ("seq_len" , min = 2 , max = MAX_TOKENS )
225330# 3. Export with static shapes
226331input_ids = torch .randint (0 , 32000 , (1 , 64 )) # Static [batch=1, seq=64]
227332exported = torch .export .export (
228333 model ,
229334 (input_ids ,),
230- dynamic_shapes = None , # Fully static
335+ dynamic_shapes = ({ 1 : seq_len },),
231336)
232337
233338# Test forward pass
@@ -238,20 +343,61 @@ def replace_rmsnorm(
238343# Export validation
239344
240345DEVICE = torch .device ("cuda:0" )
241-
242- with torch_tensorrt .logging .errors ():
243- trt_model = torch_tensorrt .dynamo .compile (
244- exported ,
245- inputs = [input_ids ],
246- enabled_precisions = {torch .float32 , torch .float16 },
247- truncate_double = True ,
248- device = DEVICE ,
249- disable_tf32 = True ,
250- use_explicit_typing = False ,
251- use_fp32_acc = True ,
252- )
253-
254- input_ids = input_ids .to (DEVICE )
255-
256- res = trt_model .forward (input_ids )
257- print (res )
346+ stream = torch .cuda .Stream ()
347+ with torch .cuda .stream (stream ):
348+ with torch_tensorrt .dynamo .Debugger (
349+ log_level = "info" ,
350+ # profile_format="trex",
351+ # save_engine_profile=True,
352+ capture_fx_graph_before = ["remove_detach" ],
353+ capture_fx_graph_after = ["replace_rmsnorm" ],
354+ logging_dir = "/home/profile/logging/torchtrt" ,
355+ engine_builder_monitor = False ,
356+ ):
357+ trt_model = torch_tensorrt .dynamo .compile (
358+ exported ,
359+ inputs = [input_ids ],
360+ enabled_precisions = {torch .float32 , torch .float16 },
361+ truncate_double = True ,
362+ device = DEVICE ,
363+ disable_tf32 = True ,
364+ use_explicit_typing = False ,
365+ use_fp32_acc = True ,
366+ use_python_runtime = True ,
367+ )
368+
369+ input_ids = input_ids .to (DEVICE )
370+
371+ res = trt_model .forward (input_ids )
372+
373+ # Benchmark TensorRT models
374+
375+ import time
376+
377+ def benchmark_model (model , input_ids , label , n_runs = 100 ):
378+ torch .cuda .synchronize ()
379+ start = time .time ()
380+ for _ in range (n_runs ):
381+ with torch .no_grad ():
382+ out = model (input_ids )
383+ torch .cuda .synchronize ()
384+ end = time .time ()
385+ print (f"{ label } : { n_runs } runs, total { (end - start ):.4f} s" )
386+ return out
387+
388+ # Warmup
389+ with torch .no_grad ():
390+ _ = trt_model (input_ids )
391+
392+ # Benchmark
393+ trt_out = benchmark_model (trt_model , input_ids , "TensorRT model" )
394+
395+ # Compare outputs
396+
397+ pytorch_logits = output .logits
398+ trt_logits = trt_out .logits
399+
400+ pytorch_logits = pytorch_logits .to (DEVICE )
401+ trt_logits = trt_logits .to (DEVICE )
402+ print ("Max abs diff:" , (pytorch_logits - trt_logits ).abs ().max ().item ())
403+ print ("Mean abs diff:" , (pytorch_logits - trt_logits ).abs ().mean ().item ())
0 commit comments