❓ [Question] Using dynamic shapes with FX frontend #2486

Closed as not planned
@HolyWu

Description

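I tried to use dynamic shapes in the FX path with the following code. The input_specs argument passed to LowerSetting appears to have no effect: the log shows the acc submodule being lowered with a static shape of [1, 1, 16, 16] and empty shape_ranges, and TRT reports an error when a 32x32 input is supplied.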

import torch
import torch.nn as nn
from torch_tensorrt.fx import InputTensorSpec, LowerSetting
from torch_tensorrt.fx.lower import Lowerer
from torch_tensorrt.fx.utils import LowerPrecision


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(1, 20, 5), nn.PReLU())

    def forward(self, input):
        return self.conv(input)


with torch.inference_mode():
    device = torch.device("cuda")
    mod = MyModule().eval().to(device).half()

    lower_setting = LowerSetting(
        lower_precision=LowerPrecision.FP16,
        min_acc_module_size=1,
        input_specs=[
            InputTensorSpec(
                shape=(1, 1, -1, -1),
                dtype=torch.half,
                device=device,
                shape_ranges=[((1, 1, 16, 16), (1, 1, 32, 32), (1, 1, 64, 64))],
            )
        ],
        dynamic_batch=False,
    )
    lowerer = Lowerer.create(lower_setting=lower_setting)
    mod_trt = lowerer(mod, [torch.rand((1, 1, 16, 16), dtype=torch.half, device=device)])

    print(mod_trt(torch.rand((1, 1, 16, 16), dtype=torch.half, device=device)).shape)
    print(mod_trt(torch.rand((1, 1, 32, 32), dtype=torch.half, device=device)).shape)

Output:
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:MyModule__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Sequential__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Conv2d__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:PReLU__AccRewrittenModule does not have attribute _compiled_call_impl
C:\Python311\Lib\site-packages\torch\overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
C:\Python311\Lib\site-packages\torch\overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
C:\Python311\Lib\site-packages\torch\overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
C:\Python311\Lib\site-packages\torch\overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:GraphModule.__new__.<locals>.GraphModuleImpl__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Module__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Module__AccRewrittenModule does not have attribute _compiled_call_impl
WARNING:torch_tensorrt.fx.tracer.acc_tracer.acc_tracer:Module__AccRewrittenModule does not have attribute _compiled_call_impl
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_matmul at 0x000001E8B08F0E00> before/after graph to C:\Users\HOLYWU~1\AppData\Local\Temp\tmpgbz4qw6c, before/after are the same = True, time elapsed = 0:00:00.026858
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_linear at 0x000001E8B08F0B80> before/after graph to C:\Users\HOLYWU~1\AppData\Local\Temp\tmpp8c1a1dw, before/after are the same = True, time elapsed = 0:00:00.000981
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fix_clamp_numerical_limits_to_fp16 at 0x000001E8B08F1440> before/after graph to C:\Users\HOLYWU~1\AppData\Local\Temp\tmp43sia5pv, before/after are the same = True, time elapsed = 0:00:00

Supported node types in the model:
acc_ops.conv2d: ((), {'input': torch.float16, 'weight': torch.float16, 'bias': torch.float16})

Unsupported node types in the model:
acc_ops.prelu: ((), {'input': torch.float16, 'weight': torch.float16})

Got 1 acc subgraphs and 1 non-acc subgraphs
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Now lowering submodule _run_on_acc_0
INFO:torch_tensorrt.fx.lower:split_name=_run_on_acc_0, input_specs=[InputTensorSpec(shape=torch.Size([1, 1, 16, 16]), dtype=torch.float16, device=device(type='cuda', index=0), shape_ranges=[], has_batch_dim=True)]
INFO:torch_tensorrt.fx.lower:Timing cache is used!
INFO:torch_tensorrt.fx.fx2trt:TRT INetwork construction elapsed time: 0:00:00.001014
INFO:torch_tensorrt.fx.fx2trt:Build TRT engine elapsed time: 0:00:00.993050
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Lowering submodule _run_on_acc_0 elapsed time 0:00:05.996300
torch.Size([1, 20, 12, 12])
[11/25/2023-13:55:00] [TRT] [E] 3: [executionContext.cpp::nvinfer1::rt::ExecutionContext::validateInputBindings::2082] Error Code 3: API Usage Error (Parameter check failed at: executionContext.cpp::nvinfer1::rt::ExecutionContext::validateInputBindings::2082, condition: profileMaxDims.d[i] >= dimensions.d[i]. Supplied binding dimension [1,1,32,32] for bindings[0] exceed min ~ max range at index 2, maximum dimension in profile is 16, minimum dimension in profile is 16, but supplied dimension is 32.
)
torch.Size([1, 20, 12, 12])
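
For readers following the error message, the failing condition (`profileMaxDims.d[i] >= dimensions.d[i]`) can be read as a per-dimension range check against the optimization profile. Below is a minimal plain-Python sketch of that check. It is not TensorRT's implementation, and the helper name `first_out_of_range_dim` is hypothetical. It compares the profile that the log suggests was built (min and max both `[1, 1, 16, 16]`) with the one requested through `shape_ranges`.

```python
from typing import Optional, Sequence


def first_out_of_range_dim(
    supplied: Sequence[int],
    profile_min: Sequence[int],
    profile_max: Sequence[int],
) -> Optional[int]:
    """Return the index of the first dimension outside [min, max], or None.

    Hypothetical helper for illustration only; it mirrors the idea of the
    range check named in the error message, not TensorRT's actual code.
    """
    if not (len(supplied) == len(profile_min) == len(profile_max)):
        raise ValueError("supplied shape and profile must have the same rank")
    for i, (dim, lo, hi) in enumerate(zip(supplied, profile_min, profile_max)):
        if not (lo <= dim <= hi):
            return i
    return None


# Profile implied by the log: the engine was built for a static 16x16 input.
static_min, static_max = (1, 1, 16, 16), (1, 1, 16, 16)

# Profile requested in shape_ranges: min 16x16, max 64x64.
requested_min, requested_max = (1, 1, 16, 16), (1, 1, 64, 64)

supplied = (1, 1, 32, 32)
print(first_out_of_range_dim(supplied, static_min, static_max))        # 2
print(first_out_of_range_dim(supplied, requested_min, requested_max))  # None
```

With the static profile the check fails at index 2, which matches the index reported in the error. With the requested range the same input would be accepted, which is consistent with the shape_ranges not having reached the engine build.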

Environment

  • PyTorch Version (e.g., 1.0): 2.1.1+cu121
  • CPU Architecture: x86-64
  • OS (e.g., Linux): Windows 11
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11.6
  • CUDA version:
  • GPU models and configuration: GeForce RTX 3050
  • Any other relevant information:
