
[Bug] monai.VNet compilation fails due to 10-dimensional expand causing rank error #3330

Open
@chohk88

Description

When a MONAI VNet model is compiled with Torch-TensorRT, an aten.expand operation produces a 10D tensor, which exceeds TensorRT's 8-dimension limit. Conversion fails when the converter tries to set IShuffleLayer.reshape_dims to this 10D shape.

Steps to Reproduce:

from monai.networks.nets import VNet
import torch
import torch_tensorrt

device = "cuda:0"
model = VNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    act="relu"
).to(device).half().eval()

input_tensor = torch.randn(1, 1, 128, 128, 128, device=device).half()

backend = "torch_tensorrt"

# Compile the model with Torch-TensorRT backend
model = torch.compile(
    model,
    backend=backend,
    options={
        "use_python_runtime": False,
        "enabled_precisions": {torch.float16},
        "truncate_double": True,
        "debug": True,
        "min_block_size": 1,
    },
    dynamic=False,
)

with torch.no_grad():
    output = model(input_tensor)
print(output)

Running this code fails during compilation with either a ValueError or a TypeError, both caused by the unsupported 10D shape.
During the compilation, the graph contains a node:

%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg1_1, [1, 16, 1, 1, 1, 1, 1, 128, 128, 128]), kwargs = {})

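This node produces the 10D shape (1, 16, 1, 1, 1, 1, 1, 128, 128, 128). The converter fails on the expand node itself: prepend_ones pads the 5D input to 10 dimensions and assigns the result to IShuffleLayer.reshape_dims, but TensorRT's IShuffleLayer only supports tensors of up to 8 dimensions. Subsequent operations, such as permute, also depend on this 10D shape.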
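The failing tuple can be reproduced with plain arithmetic. The traceback ends in prepend_ones, which runs layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape). The sketch below assumes the 5D input (1, 1, 128, 128, 128) is padded to the rank of the 10D expand target, and it uses the 8-dimension limit described above. prepend_ones_shape and check_rank are hypothetical helpers that only mirror the arithmetic. They are not Torch-TensorRT or TensorRT functions.

```python
# Hypothetical helpers that mirror the shape arithmetic in prepend_ones.
# They are illustrations only, not part of the Torch-TensorRT API.
MAX_DIMS = 8  # rank limit for TensorRT shapes, as described in this issue


def prepend_ones_shape(shape, target_rank):
    """Pad `shape` with leading 1s so it has `target_rank` dimensions."""
    num_prepend_ones = target_rank - len(shape)
    return (1,) * num_prepend_ones + tuple(shape)


def check_rank(dims, max_dims=MAX_DIMS):
    """Raise ValueError if `dims` has more entries than the limit allows."""
    if len(dims) > max_dims:
        raise ValueError(
            f"Input length {len(dims)}. Max expected length is {max_dims}"
        )
    return dims


input_shape = (1, 1, 128, 128, 128)  # shape of the expand input (%arg1_1)
expand_target = [1, 16, 1, 1, 1, 1, 1, 128, 128, 128]

padded = prepend_ones_shape(input_shape, len(expand_target))
print(padded)  # (1, 1, 1, 1, 1, 1, 1, 128, 128, 128)

try:
    check_rank(padded)
except ValueError as err:
    print(err)  # Input length 10. Max expected length is 8
```

The padded tuple is exactly the value in the "Invoked with:" line of the traceback. This shows that the rank is already 10 before TensorRT sees the layer.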

Error Message:

ValueError: Input length 10. Max expected length is 8

or

DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node expand (kind: aten.expand.default, args: ('arg1_1 <Node>', ['1 <int>', '16 <int>', '1 <int>', '1 <int>', '1 <int>', '1 <int>', '1 <int>', '128 <int>', '128 <int>', '128 <int>']))
Traceback (most recent call last):
  File "/opt/torch_tensorrt/run_clara_vnet.py", line 43, in <module>
    output = model(input_tensor)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 554, in _fn
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1428, in __call__
    return self._torchdynamo_orig_callable(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1211, in __call__
    result = self._inner_convert(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 548, in __call__
    return _compile(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 981, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 707, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 742, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
    transformations(instructions, code_options)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 232, in _fn
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 661, in transform
    tracer.run()
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2909, in run
    super().run()
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in run
    while self.step():
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1027, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3100, in RETURN_VALUE
    self._return(inst)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3085, in _return
    self.output.compile_subgraph(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1140, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1411, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1458, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1507, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1488, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/__init__.py", line 2323, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 44, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 52, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 114, in _pretraced_backend
    trt_compiled = compile_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 494, in compile_module
    trt_module = convert_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 141, in convert_module
    interpreter_result = interpret_module_to_result(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 120, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 620, in run
    self._construct_trt_network_def()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 351, in _construct_trt_network_def
    super().run()
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 686, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/fx/interpreter.py", line 228, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 795, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 533, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1170, in aten_ops_expand
    return impl.slice.expand(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 232, in expand
    input_t = prepend_ones(
  File "/opt/torch_tensorrt/py/torch_tensorrt/fx/converters/converter_utils.py", line 370, in prepend_ones
    layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: tensorrt_bindings.tensorrt.IShuffleLayer, arg1: tensorrt_bindings.tensorrt.Dims) -> None

Invoked with: <tensorrt_bindings.tensorrt.IShuffleLayer object at 0x7f3f4e5096f0>, (1, 1, 1, 1, 1, 1, 1, 128, 128, 128)

While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg1_1, [1, 16, 1, 1, 1, 1, 1, 128, 128, 128]), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e37abb0>: ((1, 1, 128, 128, 128), torch.float16, False, (2097152, 2097152, 16384, 128, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e3531f0>: ((16, 1, 5, 5, 5), torch.float16, True, (125, 125, 25, 5, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e5098b0>: ((1, 16, 128, 128, 128), torch.float16, False, (33554432, 2097152, 16384, 128, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e508bf0>: ((16,), torch.float16, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e508cf0>: ((16,), torch.float16, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e5093f0>: ((16,), torch.float16, False, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e509b30>: ((16,), torch.float16, False, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e58b4f0>: ((1, 16, 128, 128, 128), torch.float16, False, (33554432, 2097152, 16384, 128, 1), torch.contiguous_format, False, {})}})
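To find every node that would hit this limit before conversion starts, the output shapes can be scanned ahead of time. The sketch below runs over plain (name, shape) pairs copied from the log above. It assumes that a real pass would read each shape from the traced graph's node metadata instead, and that step is not shown here. find_over_rank_nodes is a hypothetical helper.

```python
# Hypothetical pre-conversion check that flags nodes whose output rank
# exceeds the 8-dimension limit. The shapes are copied from the log in this
# issue; a real pass would read them from the traced graph instead.
MAX_DIMS = 8


def find_over_rank_nodes(node_shapes, max_dims=MAX_DIMS):
    """Return (name, rank) for every node whose shape has too many dims."""
    return [
        (name, len(shape))
        for name, shape in node_shapes
        if len(shape) > max_dims
    ]


node_shapes = [
    ("arg1_1", (1, 1, 128, 128, 128)),
    ("expand", (1, 16, 1, 1, 1, 1, 1, 128, 128, 128)),
]

for name, rank in find_over_rank_nodes(node_shapes):
    print(f"{name}: rank {rank} exceeds {MAX_DIMS}")  # expand: rank 10 exceeds 8
```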
