Skip to content

Commit 4f0bb6f

Browse files
authored
fix: Fix CI issues due to unintended fake tensor creation in torch.compile tests (#3416)
1 parent 9b78101 commit 4f0bb6f

File tree

1 file changed

+41
-34
lines changed

1 file changed

+41
-34
lines changed

py/torch_tensorrt/dynamo/utils.py

+41-34
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tensorrt as trt
1313
import torch
1414
from torch._subclasses.fake_tensor import FakeTensor
15+
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
1516
from torch_tensorrt._Device import Device
1617
from torch_tensorrt._enums import dtype
1718
from torch_tensorrt._features import ENABLED_FEATURES
@@ -256,48 +257,54 @@ def prepare_inputs(
256257
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
257258
disable_memory_format_check: bool = False,
258259
) -> Any:
259-
if inputs is None:
260-
return None
261-
262-
elif isinstance(inputs, Input):
263-
return inputs
260+
"""
261+
We take a nested group of torch.Tensors or scalars and convert them into torchtrt.Input's
262+
"""
263+
# Any tensors created inside this call will be FakeTensors if it's inside a torch.compile session
264+
# So, we disable fake mode temporarily.
265+
with unset_fake_temporarily():
266+
if inputs is None:
267+
return None
264268

265-
elif isinstance(inputs, (torch.Tensor, int, float, bool)):
266-
return Input.from_tensor(
267-
torch.tensor(inputs),
268-
disable_memory_format_check=disable_memory_format_check,
269-
)
269+
elif isinstance(inputs, Input):
270+
return inputs
270271

271-
elif isinstance(inputs, (list, tuple)):
272-
torchtrt_input_list = []
273-
for input_obj in inputs:
274-
torchtrt_input = prepare_inputs(
275-
input_obj, disable_memory_format_check=disable_memory_format_check
272+
elif isinstance(inputs, (torch.Tensor, int, float, bool)):
273+
return Input.from_tensor(
274+
torch.tensor(inputs),
275+
disable_memory_format_check=disable_memory_format_check,
276276
)
277-
torchtrt_input_list.append(torchtrt_input)
278-
279-
return (
280-
torchtrt_input_list
281-
if isinstance(inputs, list)
282-
else tuple(torchtrt_input_list)
283-
)
284277

285-
elif isinstance(inputs, dict):
286-
torchtrt_inputs_dict: Dict[Any, Any] = dict()
278+
elif isinstance(inputs, (list, tuple)):
279+
torchtrt_input_list = []
280+
for input_obj in inputs:
281+
torchtrt_input = prepare_inputs(
282+
input_obj, disable_memory_format_check=disable_memory_format_check
283+
)
284+
torchtrt_input_list.append(torchtrt_input)
287285

288-
for key, input_obj in inputs.items():
289-
torchtrt_input = prepare_inputs(
290-
input_obj, disable_memory_format_check=disable_memory_format_check
286+
return (
287+
torchtrt_input_list
288+
if isinstance(inputs, list)
289+
else tuple(torchtrt_input_list)
291290
)
292-
torchtrt_inputs_dict[key] = torchtrt_input
293291

294-
return torchtrt_inputs_dict
292+
elif isinstance(inputs, dict):
293+
torchtrt_inputs_dict: Dict[Any, Any] = dict()
295294

296-
else:
297-
raise ValueError(
298-
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
299-
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
300-
)
295+
for key, input_obj in inputs.items():
296+
torchtrt_input = prepare_inputs(
297+
input_obj, disable_memory_format_check=disable_memory_format_check
298+
)
299+
torchtrt_inputs_dict[key] = torchtrt_input
300+
301+
return torchtrt_inputs_dict
302+
303+
else:
304+
raise ValueError(
305+
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
306+
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
307+
)
301308

302309

303310
def parse_complex_tensor_structs(

0 commit comments

Comments
 (0)