
❓ [Question] Using dynamic shapes with FX frontend #2486

Closed as not planned
@HolyWu

Description

I tried to use dynamic shapes in the FX path with the following code. The input_specs argument passed to LowerSetting seems to have no effect, and TensorRT reports an error when the module is called with a 32x32 input.
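For reference, a -1 in an InputTensorSpec shape marks a dynamic dimension, and each shape_ranges entry is a (min, opt, max) triple of full shapes. The sketch below is plain Python, not part of Torch-TensorRT: `check_spec` is a hypothetical helper that only confirms the spec used in this report is self-consistent. If it passes, the failure is unlikely to come from a malformed spec.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def check_spec(shape: Shape, shape_ranges: Sequence[Tuple[Shape, Shape, Shape]]) -> None:
    """Raise ValueError if a dynamic-shape spec is inconsistent (hypothetical helper).

    Assumes -1 marks a dynamic dimension and each shape_ranges entry
    is a (min, opt, max) triple of full shapes.
    """
    for lo, opt, hi in shape_ranges:
        for name, s in (("min", lo), ("opt", opt), ("max", hi)):
            if len(s) != len(shape):
                raise ValueError(f"{name} shape {s} has rank {len(s)}, expected {len(shape)}")
        for i, dim in enumerate(shape):
            if dim == -1:
                # Dynamic dimension: the triple must be ordered min <= opt <= max.
                if not lo[i] <= opt[i] <= hi[i]:
                    raise ValueError(f"dim {i}: need min <= opt <= max, got {lo[i]}, {opt[i]}, {hi[i]}")
            elif not lo[i] == opt[i] == hi[i] == dim:
                # Static dimension: all three profile shapes must equal it.
                raise ValueError(f"dim {i} is static ({dim}) but the range gives {lo[i]}, {opt[i]}, {hi[i]}")


# The spec from this report: batch and channel fixed, H and W dynamic.
check_spec((1, 1, -1, -1), [((1, 1, 16, 16), (1, 1, 32, 32), (1, 1, 64, 64))])
print("spec is consistent")
```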

import torch
import torch.nn as nn
from torch_tensorrt.fx import InputTensorSpec, LowerSetting
from torch_tensorrt.fx.lower import Lowerer
from torch_tensorrt.fx.utils import LowerPrecision


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(1, 20, 5), nn.PReLU())

    def forward(self, input):
        return self.conv(input)


with torch.inference_mode():
    device = torch.device("cuda")
    mod = MyModule().eval().to(device).half()

    lower_setting = LowerSetting(
        lower_precision=LowerPrecision.FP16,
        min_acc_module_size=1,
        input_specs=[
            InputTensorSpec(
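                # Batch and channel are fixed; H and W (-1) are dynamic.
                shape=(1, 1, -1, -1),
                dtype=torch.half,
                device=device,
                # One (min, opt, max) triple of shapes for the dynamic dims.
                shape_ranges=[((1, 1, 16, 16), (1, 1, 32, 32), (1, 1, 64, 64))],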
            )
        ],
        dynamic_batch=False,
    )
    lowerer = Lowerer.create(lower_setting=lower_setting)
    mod_trt = lowerer(mod, [torch.rand((1, 1, 16, 16), dtype=torch.half, device=device)])

    print(mod_trt(torch.rand((1, 1, 16, 16), dtype=torch.half, device=device)).shape)
    print(mod_trt(torch.rand((1, 1, 32, 32), dtype=torch.half, device=device)).shape)
Output:

WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:MyModule__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Sequential__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Conv2d__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:PReLU__AccRewrittenModule does not have attribute _compiled_call_impl
C:\Python311\Lib\site-packages\torch\overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
C:\Python311\Lib\site-packages\torch\overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
C:\Python311\Lib\site-packages\torch\overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
C:\Python311\Lib\site-packages\torch\overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:GraphModule.__new__.<locals>.GraphModuleImpl__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Module__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Module__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Module__AccRewrittenModule does not have attribute _compiled_call_impl
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_matmul at 0x000001E8B08F0E00> before/after graph to C:\Users\HOLYWU~1\AppData\Local\Temp\tmpgbz4qw6c, before/after are the same = True, time elapsed = 0:00:00.026858
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_linear at 0x000001E8B08F0B80> before/after graph to C:\Users\HOLYWU~1\AppData\Local\Temp\tmpp8c1a1dw, before/after are the same = True, time elapsed = 0:00:00.000981
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fix_clamp_numerical_limits_to_fp16 at 0x000001E8B08F1440> before/after graph to C:\Users\HOLYWU~1\AppData\Local\Temp\tmp43sia5pv, before/after are the same = True, time elapsed = 0:00:00

Supported node types in the model:
acc_ops.conv2d: ((), {'input': torch.float16, 'weight': torch.float16, 'bias': torch.float16})

Unsupported node types in the model:
acc_ops.prelu: ((), {'input': torch.float16, 'weight': torch.float16})

Got 1 acc subgraphs and 1 non-acc subgraphs
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Now lowering submodule _run_on_acc_0
INFO:torch_tensorrt.fx.lower:split_name=_run_on_acc_0, input_specs=[InputTensorSpec(shape=torch.Size([1, 1, 16, 16]), dtype=torch.float16, device=device(type='cuda', index=0), shape_ranges=[], has_batch_dim=True)]
INFO:torch_tensorrt.fx.lower:Timing cache is used!
INFO:torch_tensorrt.fx.fx2trt:TRT INetwork construction elapsed time: 0:00:00.001014
INFO:torch_tensorrt.fx.fx2trt:Build TRT engine elapsed time: 0:00:00.993050
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Lowering submodule _run_on_acc_0 elapsed time 0:00:05.996300
torch.Size([1, 20, 12, 12])
[11/25/2023-13:55:00] [TRT] [E] 3: [executionContext.cpp::nvinfer1::rt::ExecutionContext::validateInputBindings::2082] Error Code 3: API Usage Error (Parameter check failed at: executionContext.cpp::nvinfer1::rt::ExecutionContext::validateInputBindings::2082, condition: profileMaxDims.d[i] >= dimensions.d[i]. Supplied binding dimension [1,1,32,32] for bindings[0] exceed min ~ max range at index 2, maximum dimension in profile is 16, minimum dimension in profile is 16, but supplied dimension is 32.
)
torch.Size([1, 20, 12, 12])
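The log line for _run_on_acc_0 reports a fixed shape of [1, 1, 16, 16] with `shape_ranges=[]`, which suggests the TRT submodule was built from the example input rather than from the input_specs given to LowerSetting. A profile built from a single static shape has min = opt = max, so any other spatial size is rejected. The sketch below models that check in plain Python; `static_profile`, `first_violation`, and `conv2d_out` are hypothetical helpers, not Torch-TensorRT or TensorRT APIs. It also computes what a 5x5 convolution should produce: a 32x32 input ought to give 28x28, so the second printed torch.Size([1, 20, 12, 12]) is not a valid result for that input.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]
Profile = Tuple[Shape, Shape, Shape]  # (min, opt, max)


def static_profile(example: Shape) -> Profile:
    """Profile derived from a single example shape: min == opt == max."""
    return example, example, example


def first_violation(profile: Profile, dims: Shape) -> Optional[int]:
    """Index of the first dim outside [min, max], or None if dims fit.

    Simplified model of the bounds check named in the TRT error
    (profileMaxDims.d[i] >= dimensions.d[i]), extended to the min side.
    """
    lo, _, hi = profile
    for i, d in enumerate(dims):
        if not lo[i] <= d <= hi[i]:
            return i
    return None


def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output size of nn.Conv2d along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


profile = static_profile((1, 1, 16, 16))
print(first_violation(profile, (1, 1, 16, 16)))  # None: matches the example input
print(first_violation(profile, (1, 1, 32, 32)))  # 2: same index as in the TRT error

intended = ((1, 1, 16, 16), (1, 1, 32, 32), (1, 1, 64, 64))
print(first_violation(intended, (1, 1, 32, 32)))  # None: the intended profile accepts it

print(conv2d_out(16, kernel=5), conv2d_out(32, kernel=5))  # 12 28
```

Under these assumptions, the 16x16 call succeeds only because it coincides with the example input used for lowering.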

Environment

  • PyTorch version: 2.1.1+cu121
  • CPU architecture: x86-64
  • OS: Windows 11
  • How PyTorch was installed: pip
  • Python version: 3.11.6
  • GPU models and configuration: GeForce RTX 3050

Labels: question (Further information is requested)
