Bug Description
When using CUDA graphs, if a graph break produces a PyTorch subgraph and the input contains a nested dictionary, the CUDA graph module fails.
In the following example, the forward function takes two inputs, but the graph module receives three, which raises an error.
To Reproduce
import torch
import torch_tensorrt


class TestModel(torch.nn.Module):
    def forward(self, x, additional_param: dict):
        x = x + additional_param['y']
        return x * additional_param['z'] + 5


device = "cuda:0"
inputs = (torch.rand(1).to(device),)
kwarg_inputs = {
    "additional_param": {
        'y': torch.rand(1).to(device),
        'z': torch.rand(1).to(device),
    }
}
model = TestModel().to(device)
# Forcing mul to run in PyTorch creates a graph break and a PyTorch subgraph.
compiled_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    arg_inputs=inputs,
    kwarg_inputs=kwarg_inputs,
    min_block_size=1,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)
with torch_tensorrt.runtime.enable_cudagraphs(
    compiled_model
) as cudagraphs_module:
    cudagraphs_module(*inputs, **kwarg_inputs)
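
For comparison, the same computation with the dictionary flattened into plain keyword arguments does not appear to hit the input-count mismatch. The sketch below is only a diagnostic variant of the reproduction, assuming the problem is specific to nested dict kwargs; FlatKwargModel and flat_kwargs are made-up names, and the flattening is an assumed workaround, not a confirmed fix.

# Sketch (assumption, not a confirmed fix): same model with flat kwargs
# instead of a nested dict, so both the CUDA graph module and the PyTorch
# subgraph see the same set of inputs.
import torch
import torch_tensorrt


class FlatKwargModel(torch.nn.Module):
    def forward(self, x, y, z):
        x = x + y
        return x * z + 5


device = "cuda:0"
inputs = (torch.rand(1).to(device),)
flat_kwargs = {
    "y": torch.rand(1).to(device),
    "z": torch.rand(1).to(device),
}
compiled_flat = torch_tensorrt.compile(
    FlatKwargModel().to(device),
    ir="dynamo",
    arg_inputs=inputs,
    kwarg_inputs=flat_kwargs,
    min_block_size=1,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},  # still forces a graph break
)
with torch_tensorrt.runtime.enable_cudagraphs(compiled_flat) as cudagraphs_module:
    cudagraphs_module(*inputs, **flat_kwargs)

Flattening changes the model's signature, so this is only a comparison point; models that genuinely need nested dict inputs still need the bug fixed.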
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): nightly
- PyTorch Version (e.g. 1.0): nightly
- CPU Architecture: x86
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- GPU models and configuration: A40