Bug Description
Passing a boolean value inside a dict to the kwarg_inputs parameter of torch_tensorrt.compile raises:
ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}
It seems that, apart from the collection types (list, tuple, dict), only torch.Tensor and torch_tensorrt.Input values are accepted at leaf level. This contradicts the documentation at https://pytorch.org/TensorRT/py_api/torch_tensorrt.html?highlight=compile, which declares the parameter as:
kwarg_inputs: Optional[dict[Any, Any]] = None
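The check implied by the error message can be sketched as a recursive walk over the input structure. This is only an illustration of the observed behavior, not the actual code in torch_tensorrt/dynamo/utils.py: the helper name find_rejected_leaves is hypothetical, and FakeTensor is a stand-in for torch.Tensor so that the snippet runs without PyTorch installed.

```python
from typing import Any, Iterator, Tuple


def find_rejected_leaves(
    value: Any,
    allowed_leaf_types: Tuple[type, ...],
    path: str = "kwarg_inputs",
) -> Iterator[Tuple[str, type]]:
    """Yield (path, type) for every leaf that is not an allowed leaf type.

    Hypothetical helper: it mirrors the behavior described by the error
    message (collections are traversed, other leaves are rejected).
    """
    if isinstance(value, allowed_leaf_types):
        return  # accepted leaf, nothing to report
    if isinstance(value, dict):
        for key, item in value.items():
            yield from find_rejected_leaves(
                item, allowed_leaf_types, f"{path}[{key!r}]"
            )
    elif isinstance(value, (list, tuple)):
        for index, item in enumerate(value):
            yield from find_rejected_leaves(
                item, allowed_leaf_types, f"{path}[{index}]"
            )
    else:
        # Any other leaf (bool, int, float, str, None, ...) is reported.
        yield path, type(value)


class FakeTensor:
    """Stand-in for torch.Tensor (assumption, keeps the sketch dependency-free)."""


kwarg_inputs = {"additional_param": True, "nested": [FakeTensor(), 3.0]}
rejected = list(find_rejected_leaves(kwarg_inputs, (FakeTensor,)))
for leaf_path, leaf_type in rejected:
    print(f"{leaf_path}: {leaf_type.__name__} would be rejected")
```

With real inputs, the allowed leaf types would be (torch.Tensor, torch_tensorrt.Input), as listed in the error message. Running such a pre-check before torch_tensorrt.compile would point at the offending key instead of only naming the type.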
To Reproduce
Steps to reproduce the behavior:
- Execute the following minimal example:
import torch
import torch_tensorrt

class TestModel(torch.nn.Module):
    def forward(self, param1, additional_param: bool | None = None):
        pass

compiled_model = torch_tensorrt.compile(
    TestModel(),
    ir="dynamo",
    inputs=[torch.rand(1)],
    kwarg_inputs={
        "additional_param": True
    },
)
- The result is
Traceback (most recent call last):
  File "...\test_bug.py", line 8, in <module>
    compiled_model = torch_tensorrt.compile(
  File "...\lib\site-packages\torch_tensorrt\_compile.py", line 284, in compile
    torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
  File "...\lib\site-packages\torch_tensorrt\dynamo\utils.py", line 272, in prepare_inputs
    torchtrt_input = prepare_inputs(
  File "...\lib\site-packages\torch_tensorrt\dynamo\utils.py", line 280, in prepare_inputs
    raise ValueError(
ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}
Expected behavior
The minimal example should compile without error. IMHO, values other than torch tensors should be accepted in both inputs and kwarg_inputs. It would also help if the documentation explained in more detail how the compiler treats inputs and what happens to them at runtime of the compiled model.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
I am sorry, I do not know a canonical way of "turning on debug messages" in Python, so I cannot turn this instruction into something actionable.
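One generic way to surface debug messages from a Python library is the standard logging module. The sketch below assumes that Torch-TensorRT's Python code logs through standard-library loggers under the "torch_tensorrt" namespace; this has not been verified against torch_tensorrt 2.6.0, so treat the logger name as an assumption.

```python
import logging

# Attach a console handler to the root logger (default root level: WARNING).
logging.basicConfig(format="%(levelname)s %(name)s: %(message)s")

# Assumption: the library's loggers live under the "torch_tensorrt" namespace.
trt_logger = logging.getLogger("torch_tensorrt")
trt_logger.setLevel(logging.DEBUG)

# Child loggers such as "torch_tensorrt.dynamo" inherit the DEBUG level.
child_logger = logging.getLogger("torch_tensorrt.dynamo")
print(child_logger.getEffectiveLevel() == logging.DEBUG)
```

This would have to run before the torch_tensorrt.compile call so that messages emitted during compilation are captured.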
- Torch-TensorRT Version (e.g. 1.0.0) / PyTorch Version (e.g. 1.0):
tensorrt==10.7.0
tensorrt_cu12==10.7.0
tensorrt_cu12_bindings==10.7.0
tensorrt_cu12_libs==10.7.0
torch==2.6.0+cu124
torch_tensorrt==2.6.0+cu124
- CPU Architecture: Intel x86_64
- OS (e.g., Linux): Windows 10
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Python version: Python 3.10.16
- CUDA version: 12.4