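Bug Description
Compiling a model with a static input shape via `torch_tensorrt.compile` fails inside `torch.export` with a `UserError` reporting a mismatch between the structure of the inputs and `dynamic_shapes`.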
To Reproduce
Steps to reproduce the behavior:
- Install the torch_tensorrt 2.8 wheel for cu126 from https://pypi.jetson-ai-lab.io/jp6/cu126 on a Jetson Orin Nano running JetPack 6.2.
- Compile a model with `torch_tensorrt.compile` using a static input shape.
 
Code Sample
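```python
import torch
import torch_tensorrt

# `model` is the ERFNet network built in ERFNetPredictor.__init__ (not shown)
dummy_input = torch.randn(1, 3, 544, 960, dtype=torch.float16).cuda()  # static shape
trt_model = torch_tensorrt.compile(
    model,
    inputs=[dummy_input],
    enabled_precisions={torch.float16},  # note: the keyword is plural
    workspace_size=1 << 27,
)
```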
Stack Trace
```
Traceback (most recent call last):
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 614, in _tree_map_with_path
    return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2076, in tree_map_with_path
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2076, in <listcomp>
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1192, in flatten_up_to
    helper(self, tree, subtrees)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1189, in helper
    helper(subspec, subtree, subtrees)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1145, in helper
    raise ValueError(
ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'dict'>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 442, in <module>
    autonomy(speed_matrix=HARD_SPEED_MATRIX, spoofing = run_with_spoofing, recording_path = hard_recording_path)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 377, in autonomy
    aeolus, video_stream, teensy_state_store = spin_up_spoofed_autonomy(speed_matrix, recording_path)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 306, in spin_up_spoofed_autonomy
    prime_autonomy_pipeline(my_aeolus_run, image_stream)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 224, in prime_autonomy_pipeline
    predictor = ERFNetPredictor(aeolus.model_weights)
  File "/home/henry/git/aeolus_dev/core/cv/nueral_net/vision_pipeline/ERFNetPredictor.py", line 38, in __init__
    trt_model = torch_tensorrt.compile(
  File "/home/henry/.virtualenvs/aeolus-system/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 286, in compile
    exp_program = dynamo_trace(
  File "/home/henry/.virtualenvs/aeolus-system/lib/python3.10/site-packages/torch_tensorrt/dynamo/_tracer.py", line 79, in trace
    exp_program = export(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 304, in export
    raise e
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 271, in export
    return _export(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1150, in wrapper
    raise e
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1116, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2163, in _export
    ep = _export_for_training(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1150, in wrapper
    raise e
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1116, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2026, in _export_for_training
    export_artifact = export_func(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1923, in _non_strict_export
    ) = make_fake_inputs(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/_export/non_strict_utils.py", line 356, in make_fake_inputs
    _check_dynamic_shapes(combined_args, dynamic_shapes)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 1031, in _check_dynamic_shapes
    _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 686, in _tree_map_with_path
    _compare(tree_spec, other_tree_spec, [])
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 677, in _compare
    _compare(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 652, in _compare
    raise_mismatch_error(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 634, in raise_mismatch_error
    raise UserError(
torch._dynamo.exc.UserError: Detected mismatch between the structure of inputs and dynamic_shapes: inputs['inputs'] is a <class 'tuple'>, but dynamic_shapes['inputs'] is a <class 'dict'>
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
```
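The pure-Python sketch below reproduces the shape of this error message without torch. It rests on two assumptions: (1) hypothetically, the model's `forward` collects its tensors as `*inputs` (the issue does not show the real signature), and (2) `sketch_dynamic_shapes_args` is a simplified, hypothetical stand-in for `get_dynamic_shapes_args`, not its actual source. Binding the same arguments to the signature, which is our understanding of how `torch.export` keys its inputs before validating `dynamic_shapes`, puts a tuple under `inputs` while the spec puts a dict there.

```python
import inspect
from typing import Any, Callable


class ToyModel:
    """Hypothetical stand-in for the reporter's model (real signature unknown)."""

    def forward(self, *inputs: Any) -> Any:
        # Assumption: tensors arrive as varargs, so forward() has one parameter named "inputs".
        return inputs


def sketch_dynamic_shapes_args(forward: Callable, inputs: tuple) -> dict:
    """Hypothetical, simplified stand-in for get_dynamic_shapes_args."""
    names = list(inspect.signature(forward).parameters)
    # Pair each positional input with a parameter name; static inputs get {}.
    return {name: {} for _, name in zip(inputs, names[: len(inputs)])}


def combine_args(forward: Callable, args: tuple) -> dict:
    """Bind positional args to forward()'s signature, keyed by parameter name."""
    return dict(inspect.signature(forward).bind(*args).arguments)


model = ToyModel()
args = ("static_tensor",)  # placeholder for the single static input tensor
shapes = sketch_dynamic_shapes_args(model.forward, args)
bound = combine_args(model.forward, args)
print(type(bound["inputs"]).__name__, type(shapes["inputs"]).__name__)  # prints: tuple dict
```

If that assumption about the signature holds, the two structures disagree exactly as in the `inputs['inputs'] is a <class 'tuple'>, but dynamic_shapes['inputs'] is a <class 'dict'>` message, and passing `dynamic_shapes=None` skips the comparison altogether.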
Expected behavior
I would expect the model to compile without errors. The input has a static shape, so no dynamic-shape validation should fail.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version: 2.8
- PyTorch Version: 2.8
- CPU Architecture: ARM
- OS (e.g., Linux):
- How you installed PyTorch: Jetson-AI-Lab wheels
- Python version: 3.10
- CUDA version: 12.6
- GPU models and configuration: Jetson Orin Nano
- Any other relevant information:
 
Additional context
I believe this error comes from the call `dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)` on line 77 of `torch_tensorrt/dynamo/_tracer.py`.
The subsequent call to `export` expects `dynamic_shapes` to be `None` when the inputs have static shapes, but `get_dynamic_shapes_args` returns a dictionary instead (with an empty entry for each static input).
Replacing that call with `dynamic_shapes = None` makes the compilation succeed.
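Hard-coding `dynamic_shapes = None` works for this static case, but as far as we can tell it would also discard genuinely dynamic `torch_tensorrt.Input` ranges passed to `compile`. A narrower variant falls back to `None` only when every per-input entry is empty. The helper below is hypothetical (it is not part of Torch-TensorRT) and assumes the dictionary maps argument names to per-input specs, with an empty dict meaning "no dynamic dimensions".

```python
from typing import Any, Optional


def dynamic_shapes_or_none(dynamic_shapes: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Hypothetical helper: collapse an all-static spec to None."""
    if all(not spec for spec in dynamic_shapes.values()):
        # Every input is static (or there are no inputs), so let
        # torch.export skip dynamic-shape validation entirely.
        return None
    # At least one input declares a dynamic dimension; keep the spec unchanged.
    return dynamic_shapes
```

The idea would be to pass the result of `get_dynamic_shapes_args` through such a helper before calling `export`. This only addresses the all-static case reported here; if the spec's structure disagrees with the model's signature, a dynamic spec would presumably still fail the same validation.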