
🐛 [Bug] Encountered bug when using Torch-TensorRT #3406

Open
@cehongwang

Description

Bug Description

When using CUDA Graphs, if a graph break produces a PyTorch subgraph and the input contains a nested dictionary, CUDA Graphs execution fails.
In the following example, the forward function takes 2 inputs (x and additional_param), but the graph module expects 3 (presumably the flattened x, y, and z), which raises an error.

To Reproduce

import torch
import torch_tensorrt


class TestModel(torch.nn.Module):
    def forward(self, x, additional_param: dict):
        x = x + additional_param['y']
        # This mul runs in PyTorch (see torch_executed_ops below),
        # which splits the graph into TensorRT and PyTorch subgraphs.
        return x * additional_param['z'] + 5


device = "cuda:0"
inputs = (torch.rand(1).to(device),)
# Nested dictionary passed as a keyword argument.
kwarg_inputs = {
    "additional_param": {
        'y': torch.rand(1).to(device),
        'z': torch.rand(1).to(device),
    }
}
model = TestModel().to(device)
compiled_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    arg_inputs=inputs,
    kwarg_inputs=kwarg_inputs,
    min_block_size=1,
    # Force aten.mul to run in PyTorch to create a graph break.
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

# Fails: forward() takes 2 inputs, but the graph module expects 3.
with torch_tensorrt.runtime.enable_cudagraphs(
    compiled_model
) as cudagraphs_module:
    cudagraphs_module(*inputs, **kwarg_inputs)
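
The input-count mismatch described above can be illustrated without Torch-TensorRT. The pure-Python sketch below is a hypothetical illustration, not the library's code: flatten_leaves, flat_graph_module, and flatten_then_call are invented names, floats stand in for tensors, and it assumes the compiled graph module takes one input per flattened leaf (x, y, z) while forward takes two (x, additional_param).

```python
from typing import Any, Callable, List


def flatten_leaves(obj: Any) -> List[Any]:
    """Recursively collect leaves from nested tuples, lists and dicts.

    Hypothetical helper in the spirit of pytree flattening; dict values are
    visited in insertion order. Not Torch-TensorRT's actual implementation.
    """
    if isinstance(obj, dict):
        obj = list(obj.values())
    if isinstance(obj, (list, tuple)):
        leaves: List[Any] = []
        for item in obj:
            leaves.extend(flatten_leaves(item))
        return leaves
    return [obj]


def forward(x: float, additional_param: dict) -> float:
    # Same arithmetic as TestModel.forward, with floats instead of tensors.
    return (x + additional_param["y"]) * additional_param["z"] + 5


def flat_graph_module(x: float, y: float, z: float) -> float:
    # Hypothetical stand-in for the compiled graph module: one parameter per leaf.
    return (x + y) * z + 5


def flatten_then_call(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical adapter: flatten the caller's (args, kwargs) into leaves
    # before invoking a function that expects flat inputs.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return fn(*flatten_leaves((args, kwargs)))
    return wrapper


args = (1.0,)
kwargs = {"additional_param": {"y": 2.0, "z": 3.0}}

print(len(args) + len(kwargs))               # inputs as seen by forward(): 2
print(len(flatten_leaves((args, kwargs))))   # inputs as seen by the graph module: 3

try:
    # Passing the structured inputs straight through does not match the flat signature.
    flat_graph_module(*args, **kwargs)
except TypeError as exc:
    print("mismatch:", exc)

print(forward(*args, **kwargs))                               # 14.0
print(flatten_then_call(flat_graph_module)(*args, **kwargs))  # 14.0
```

If that assumption holds, flattening the structured inputs in the same order before invoking the graph module would be one possible direction for a fix; this has not been verified against the Torch-TensorRT CUDA Graphs runtime.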

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): nightly
  • PyTorch Version (e.g. 1.0): nightly
  • CPU Architecture: x86
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • GPU models and configuration: A40

Metadata

Labels

bug: Something isn't working