
🐛 [Bug] Cannot save models using ExportedProgram if the model has weighted layers #2341

Closed
@peri044

Description

Bug Description

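If the graph contains weighted layers (e.g. Conv2d), saving the model as an ExportedProgram, loading it back, and running inference fails because the weights and the inputs end up on different devices: the deserialized weights are on the CPU while the inputs are on CUDA.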

To Reproduce

import unittest.mock  # import the submodule explicitly; "import unittest" alone does not guarantee unittest.mock is loaded

import torch
import torch_tensorrt
from torch._export import export
from torch_tensorrt.dynamo.lowering import get_decompositions


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Weighted layer: its parameters are the tensors that end up on the wrong device
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        conv = self.conv(x)
        relu = self.relu(conv)
        mul = relu * 0.5
        return mul


input = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
model = MyModule().eval().cuda()

# Export using the Torch-TensorRT decomposition table
with unittest.mock.patch("torch._export.DECOMP_TABLE", get_decompositions(True)):
    trt_exp_program = export(model, (input,))

# Round-trip the exported program through serialization
torch._export.save(trt_exp_program, "./trt.ep")
deserialized_prog = torch._export.load("./trt.ep")

out_pyt = model(input)
# Raises the RuntimeError below: the loaded weights are on the CPU, the input is on CUDA
out_trt_ser = deserialized_prog(input).cuda()

Error message:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same 
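The error means that at least one weight of the deserialized graph (here the Conv2d weights) is on the CPU while the input is on CUDA. A minimal sketch for listing the offending tensors is below. It assumes the weights are available as a name -> tensor mapping, such as a module's state_dict(); whether a loaded ExportedProgram exposes its weights that way (for instance through a state_dict attribute) depends on the PyTorch version and is an assumption here. The helper names find_device_mismatches and move_tensors are hypothetical, not part of PyTorch or Torch-TensorRT.

```python
import torch


def find_device_mismatches(named_tensors, device):
    """Return the sorted names of tensors that are not on `device`.

    `named_tensors` is any mapping of name -> tensor, e.g. a module's
    state_dict(). Passing a loaded ExportedProgram's weights here is an
    assumption that depends on the PyTorch version.
    """
    target = torch.device(device)
    # Compare only the device type, so "cuda" matches a tensor on "cuda:0"
    return sorted(
        name
        for name, tensor in named_tensors.items()
        if tensor.device.type != target.type
    )


def move_tensors(named_tensors, device):
    """Return a new dict with every tensor moved to `device` (inputs untouched)."""
    return {name: tensor.to(device) for name, tensor in named_tensors.items()}


if __name__ == "__main__":
    # CPU-only demo: "meta" stands in for a second device such as "cuda"
    weights = torch.nn.Conv2d(3, 16, 3).state_dict()  # "weight" and "bias", on the CPU
    print(find_device_mismatches(weights, "cpu"))   # []
    print(find_device_mismatches(weights, "meta"))  # ['bias', 'weight']
```

In the reproduction above, calling find_device_mismatches with "cuda" on the loaded program's weights would be expected to report the conv parameters, matching the "weight type (torch.FloatTensor)" in the error message.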

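Expected behavior

The deserialized program should run on the CUDA input and produce the same output as the original model (out_pyt).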

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Labels: Blocked [PyTorch] (Issue is blocked by some limitation of PyTorch), bug (Something isn't working)
