
🐛 [Bug] Torch-TensorRT runs out of memory when compiling a small PointNet model with otherwise low memory requirements #1854

Closed
@airalcorn2

Bug Description

When trying to compile the small PointNet model below, Torch-TensorRT runs out of memory on a GeForce RTX 3080 Ti. The model has fairly low memory requirements: it has only 892,677 parameters (nvidia-smi reports ~778MiB with just the model loaded), and a forward pass on a tensor of shape (1, 50000, 3) uses only ~1822MiB total, so it's not clear why compiling the model takes so much memory.
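
For scale, here is a rough back-of-the-envelope estimate of the model's intrinsic footprint (assuming float32 weights and activations; the nvidia-smi readings above also include the CUDA context and allocator overhead):

# Back-of-the-envelope estimate (float32 assumed; numbers are approximate).
n_params = 892_677
weight_mib = n_params * 4 / 2**20                 # ~3.4 MiB of weights
# Largest intermediate activation: the second_mlp output of shape (1, 50000, 1024).
activation_mib = 1 * 50_000 * 1024 * 4 / 2**20    # ~195 MiB
print(f"weights: ~{weight_mib:.1f} MiB, largest activation: ~{activation_mib:.1f} MiB")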

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt

from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()

        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)

        self.classifier = nn.Linear(in_dimensions, n_classes)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        _, n_points, _ = points.shape

        first_mlp_features = self.first_mlp(points)
        second_mlp_features = self.second_mlp(first_mlp_features)

        global_features = second_mlp_features.max(dim=1)[0]
        global_features = global_features.unsqueeze(1).expand(-1, n_points, -1)

        concatenated_features = torch.cat([first_mlp_features, global_features], dim=-1)

        segmentation_features = self.segmentation_mlp(concatenated_features)

        preds = self.classifier(segmentation_features)

        return preds


def do_forward_pass(model, P, device):
    with torch.no_grad():
        points = torch.rand(1, P, 3)
        _ = model(points.to(device))

    # nvidia-smi --> 1822MiB.


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 778MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")

    P = 50000
    # Reproduces the ~1822MiB forward-pass measurement quoted above.
    do_forward_pass(model, P, device)
    inputs = [torch_tensorrt.Input((1, P, 3))]
    enabled_precisions = {torch.float}
    # Out of memory.
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )


if __name__ == "__main__":
    main()
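
One mitigation worth trying (a sketch only, on the assumption that the out-of-memory happens during the TensorRT builder's tactic search rather than in the final engine): cap the builder workspace via the workspace_size argument to torch_tensorrt.compile. The snippet reuses model, inputs, and enabled_precisions from main() above.

# Hypothetical variant of the compile call: limit the TensorRT builder
# workspace to 1 GiB (workspace_size is in bytes). A guess at a mitigation,
# not a confirmed fix.
trt_ts_module = torch_tensorrt.compile(
    model,
    inputs=inputs,
    enabled_precisions=enabled_precisions,
    workspace_size=1 << 30,
)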

Expected behavior

The model compiles without running out of GPU memory.
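
For concreteness, a successful compile would be expected to produce a drop-in replacement for the eager model, along these lines (a hypothetical check reusing model and trt_ts_module from the script above):

# Hypothetical sanity check after a successful compile: the TensorRT module's
# output should roughly match the eager PyTorch model's output.
points = torch.rand(1, 50000, 3, device="cuda:0")
with torch.no_grad():
    trt_preds = trt_ts_module(points)
    torch_preds = model(points)
print(torch.allclose(trt_preds, torch_preds, atol=1e-3))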

Environment

Build information about Torch-TensorRT can be found by turning on debug messages (see the logging sketch after the list below).

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g. 1.0): 1.13.1+cu117
  • CPU Architecture: i7-12800H
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.8.16
  • CUDA version: 11.7
  • GPU models and configuration: GeForce RTX 3080 Ti
  • Any other relevant information:
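
For reference, a sketch of how those debug messages can be enabled (assuming the torch_tensorrt.logging context managers shipped with 1.3.0; names reused from the reproduction script above):

# Assumed usage: wrap the compile call in the debug logging context manager so
# Torch-TensorRT prints build information, including version details.
with torch_tensorrt.logging.debug():
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )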

Additional context
