Commit 9ae776d

Add some logging to how we handle dynamic shapes

1 parent: b7ae84f

File tree

3 files changed (+36, -14 lines changed)

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 4 deletions
@@ -10,6 +10,7 @@
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
+
 from torch_tensorrt._utils import is_tegra_platform
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
@@ -118,9 +119,10 @@ def _pretraced_backend(
         fake_mode = detect_fake_mode(sample_inputs)
 
         # Place backend tracing within FakeTensor context allowing nonfake Tensors
-        with unittest.mock.patch.object(
-            fake_mode, "allow_non_fake_inputs", True
-        ), fake_mode:
+        with (
+            unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True),
+            fake_mode,
+        ):
             repair_input_aliasing(gm, settings)
 
             # Remove sym_int placeholders and inputs
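
The with-statement change above is purely stylistic: parenthesized context managers (officially part of the grammar since Python 3.10) replace the old continuation-style form. A minimal standalone illustration, unrelated to the TensorRT code:

import tempfile

# Both managers are entered left to right and exited in reverse order,
# exactly as with the older backslash/nested forms.
with (
    tempfile.TemporaryFile() as fa,
    tempfile.TemporaryFile() as fb,
):
    fa.write(b"left")
    fb.write(b"right")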
@@ -170,7 +172,7 @@ def _pretraced_backend(
                 engine_cache=engine_cache,
             )
             return trt_compiled
-    except (AssertionError, RuntimeError):
+    except (AssertionError, RuntimeError, TypeError):
         if not settings.pass_through_build_failures:
             logger.warning(
                 "TRT conversion failed on the subgraph. See trace above. "

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 16 additions & 2 deletions
@@ -4,6 +4,7 @@
 import torch
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.utils import contains_sym_int, extract_var_range_info

@@ -26,11 +27,21 @@ def construct_dynamic_input(
     min_shape = []
     opt_shape = []
     max_shape = []
-    for dim in input_shape:
+    for d, dim in enumerate(input_shape):
         if isinstance(dim, torch.SymInt):
             min_max_opt = extract_var_range_info(dim)
+            if min_max_opt["max"] is None:
+                logger.warning(
+                    f"Dynamic input {name} (shape: {input_shape}) has no max bound for dim {d}, attempting to use a sane default (max: min({min_max_opt['min']}) * 128). Please set an upper bound using torch._dynamo.mark_dynamic or torch.export.Dim"
+                )
+                min_max_opt["max"] = min_max_opt["min"] * 128
             min_shape.append(min_max_opt["min"])
             # if opt not exist, set it to the mean of min and max
+            if min_max_opt["opt"] is None:
+                logger.info(
+                    f"Dynamic input {name} (shape: {input_shape}) has no opt target i.e. which shape to specialize for, for dim {d}, attempting to use a sane default (opt: min({min_max_opt['min']}) + max({min_max_opt['max']}) / 2). If you want to specialized further, use torch_tensorrt.compile"
+                )
+                min_max_opt["opt"] = int(min_max_opt["min"] + min_max_opt["max"] / 2)
             opt_shape.append(
                 min_max_opt.get("opt", int(min_max_opt["min"] + min_max_opt["max"] / 2))
             )
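
To make the new defaulting concrete: when extract_var_range_info reports no max, the code assumes min * 128; when no opt is present, it evaluates int(min + max / 2), which by operator precedence is min + max/2 rather than the arithmetic mean the neighboring comment mentions. A self-contained restatement (the function name is mine, not the library's):

def default_bounds(min_val, max_val=None, opt_val=None):
    # Mirrors the commit's per-dimension defaulting rules.
    if max_val is None:
        max_val = min_val * 128  # the "sane default" named in the warning
    if opt_val is None:
        opt_val = int(min_val + max_val / 2)  # min + max/2, as written
    return min_val, opt_val, max_val

# A dimension known only to satisfy min >= 1:
assert default_bounds(1) == (1, 65, 128)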
@@ -61,7 +72,10 @@ def get_input(
     """
     if contains_sym_int(input_shape):
         return construct_dynamic_input(
-            input_shape, dtype, name=name, is_shape_tensor=is_shape_tensor
+            input_shape,
+            dtype,
+            name=name,
+            is_shape_tensor=is_shape_tensor,
         )
     else:
         return Input(
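
For reference, the remedy the new warning suggests is to give each dynamic dimension an explicit upper bound before export or compile. A sketch with a placeholder module (TinyModel and the shapes are illustrative only, not from this commit):

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2

# torch.export route: bound dim 0 of input `x` to [1, 32].
batch = torch.export.Dim("batch", min=1, max=32)
ep = torch.export.export(
    TinyModel().eval(),
    (torch.randn(8, 4),),
    dynamic_shapes={"x": {0: batch}},
)

# torch.compile route: mark dim 0 of this tensor as dynamic with bounds.
inp = torch.randn(8, 4)
torch._dynamo.mark_dynamic(inp, 0, min=1, max=32)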

py/torch_tensorrt/dynamo/utils.py

Lines changed: 14 additions & 8 deletions
@@ -26,6 +26,9 @@
 import torch
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._sympy.numbers import int_oo
+
+from packaging import version
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
@@ -36,8 +39,6 @@
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
-from packaging import version
-
 from .types import TRTDataType
 
 logger = logging.getLogger(__name__)
@@ -424,7 +425,11 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
         or expr.xreplace(shape_env.var_to_val)
     )
     assert var_range, var_val
-    min_val, max_val = int(var_range.lower), int(var_range.upper)
+    min_val, max_val = (
+        int(var_range.lower),
+        int(var_range.upper) if var_range.upper != int_oo else None,
+    )
+
     # Torchdynamo 0/1 specialization outlier
     min_val = 1 if min_val == 2 else min_val
     min_max_opt = {}
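
Context for the int_oo guard: when a SymInt's value range was never bounded above, var_range.upper is int_oo (symbolic integer infinity), which cannot be converted to a finite int, so the max is now reported as None for construct_dynamic_input's defaulting to fill in. A small check of the same conditional (assuming a torch build that ships torch.utils._sympy.numbers.int_oo, as the commit's own import implies):

from torch.utils._sympy.numbers import int_oo

def upper_or_none(upper):
    # Finite bounds pass through; an unbounded range becomes None.
    return int(upper) if upper != int_oo else None

assert upper_or_none(2048) == 2048
assert upper_or_none(int_oo) is None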
@@ -699,8 +704,9 @@ def check_module_output(
     arg_inputs: Any,
     kwarg_inputs: Any = None,
 ) -> bool:
-    old_outputs, new_outputs = refitted_module(*arg_inputs), new_module(
-        *arg_inputs, **kwarg_inputs
+    old_outputs, new_outputs = (
+        refitted_module(*arg_inputs),
+        new_module(*arg_inputs, **kwarg_inputs),
     )
     if type(old_outputs) != type(new_outputs):
         logger.warning("The output types are different. Output check is skipped.")
@@ -803,9 +809,9 @@ def copy_metadata(match_and_replacements: List[Any]) -> None:
     """
     for match_and_replacement in match_and_replacements:
         anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor]
-        assert (
-            len(match_and_replacement.replacements) == 1
-        ), "Found more than 1 replacements for the anchor node."
+        assert len(match_and_replacement.replacements) == 1, (
+            "Found more than 1 replacements for the anchor node."
+        )
         replacement_node = match_and_replacement.replacements[0]
         replacement_node.meta = anchor_node.meta
