[Bug] monai.VNet compilation fails due to 10-dimensional expand causing rank error #3330
Description
When compiling a MONAI VNet model with Torch-TensorRT, an aten.expand operation produces a 10-D tensor, exceeding TensorRT's 8-D rank limit. Conversion then fails when the converter sets IShuffleLayer.reshape_dims.
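For context, the 8-D cap is TensorRT's own tensor-rank limit, not something Torch-TensorRT imposes. A quick sanity check against the Python bindings (assuming a standard tensorrt install):

import tensorrt as trt

# TensorRT's maximum tensor rank; this should print 8 on the affected releases.
print(trt.Dims.MAX_DIMS)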
Steps to Reproduce:
from monai.networks.nets import VNet
import torch
import torch_tensorrt

device = "cuda:0"

model = VNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    act="relu",
).to(device).half().eval()

input_tensor = torch.randn(1, 1, 128, 128, 128, device=device).half()

backend = "torch_tensorrt"

# Compile the model with the Torch-TensorRT backend
model = torch.compile(
    model,
    backend=backend,
    options={
        "use_python_runtime": False,
        "enabled_precisions": {torch.float16},
        "truncate_double": True,
        "debug": True,
        "min_block_size": 1,
    },
    dynamic=False,
)

with torch.no_grad():
    output = model(input_tensor)

print(output)
Running this code triggers a TypeError or ValueError related to the unsupported 10-D shape.
During compilation, the traced graph contains the node:
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg1_1, [1, 16, 1, 1, 1, 1, 1, 128, 128, 128]), kwargs = {})
This produces a 10-D shape (1, 16, 1, 1, 1, 1, 1, 128, 128, 128). Subsequent operations (such as permute) rely on this shape, so the TensorRT converter fails: TensorRT's IShuffleLayer only supports tensors of up to 8 dimensions.
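The 10-D expand is consistent with how aten decomposes torch.Tensor.repeat (raise the rank with interleaved singleton dims, expand, permute, reshape), and VNet's input block broadcasts its single input channel to 16 channels, presumably via repeat. That attribution is an inference from the logged shapes and the permute that follows, not confirmed against MONAI's source. A minimal sketch of the suspected pattern, without MONAI:

import torch

# Suspected origin (assumption): a 1 -> 16 channel broadcast via repeat.
# Under torch.compile, repeat on a 5-D input decomposes into a 10-D expand
# followed by permute and reshape, matching the graph node shown above.
x = torch.randn(1, 1, 128, 128, 128)
y = x.repeat(1, 16, 1, 1, 1)
print(y.shape)  # torch.Size([1, 16, 128, 128, 128])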
Error Message:
ValueError: Input length 10. Max expected length is 8
or
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node expand (kind: aten.expand.default, args: ('arg1_1 <Node>', ['1 <int>', '16 <int>', '1 <int>', '1 <int>', '1 <int>', '1 <int>', '1 <int>', '128 <int>', '128 <int>', '128 <int>']))
Traceback (most recent call last):
  File "/opt/torch_tensorrt/run_clara_vnet.py", line 43, in <module>
    output = model(input_tensor)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 554, in _fn
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1428, in __call__
    return self._torchdynamo_orig_callable(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1211, in __call__
    result = self._inner_convert(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 548, in __call__
    return _compile(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 981, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 707, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 742, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
    transformations(instructions, code_options)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 232, in _fn
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 661, in transform
    tracer.run()
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2909, in run
    super().run()
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in run
    while self.step():
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1027, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3100, in RETURN_VALUE
    self._return(inst)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3085, in _return
    self.output.compile_subgraph(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1140, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1411, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1458, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1507, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1488, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/__init__.py", line 2323, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 44, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 52, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 114, in _pretraced_backend
    trt_compiled = compile_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 494, in compile_module
    trt_module = convert_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 141, in convert_module
    interpreter_result = interpret_module_to_result(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 120, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 620, in run
    self._construct_trt_network_def()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 351, in _construct_trt_network_def
    super().run()
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 686, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/fx/interpreter.py", line 228, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 795, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 533, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1170, in aten_ops_expand
    return impl.slice.expand(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 232, in expand
    input_t = prepend_ones(
  File "/opt/torch_tensorrt/py/torch_tensorrt/fx/converters/converter_utils.py", line 370, in prepend_ones
    layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
TypeError: (): incompatible function arguments. The following argument types are supported:
1. (arg0: tensorrt_bindings.tensorrt.IShuffleLayer, arg1: tensorrt_bindings.tensorrt.Dims) -> None
Invoked with: <tensorrt_bindings.tensorrt.IShuffleLayer object at 0x7f3f4e5096f0>, (1, 1, 1, 1, 1, 1, 1, 128, 128, 128)
While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg1_1, [1, 16, 1, 1, 1, 1, 1, 128, 128, 128]), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e37abb0>: ((1, 1, 128, 128, 128), torch.float16, False, (2097152, 2097152, 16384, 128, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e3531f0>: ((16, 1, 5, 5, 5), torch.float16, True, (125, 125, 25, 5, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e5098b0>: ((1, 16, 128, 128, 128), torch.float16, False, (33554432, 2097152, 16384, 128, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e508bf0>: ((16,), torch.float16, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e508cf0>: ((16,), torch.float16, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e5093f0>: ((16,), torch.float16, False, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e509b30>: ((16,), torch.float16, False, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3f4e58b4f0>: ((1, 16, 128, 128, 128), torch.float16, False, (33554432, 2097152, 16384, 128, 1), torch.contiguous_format, False, {})}})
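If the oversized expand does come from a channel-broadcast repeat (the assumption above), one possible user-side workaround is to patch the broadcast so no intermediate exceeds 8 dimensions; a sketch:

import torch

# Equivalent 1 -> 16 channel broadcasts that stay 5-D throughout
# (assumes the repeat only tiles a size-1 channel dim, as suspected here).
x = torch.randn(1, 1, 128, 128, 128)
y = x.expand(-1, 16, -1, -1, -1)  # broadcast view, no copy, rank stays 5
y = torch.cat([x] * 16, dim=1)    # materialized alternative, also 5-D

A converter-side fix could instead fold dims that are 1 in both the input and the target shape before building the shuffle layer, but that is only a direction, not a tested patch.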