|
12 | 12 | import tensorrt as trt
|
13 | 13 | import torch
|
14 | 14 | from torch._subclasses.fake_tensor import FakeTensor
|
| 15 | +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily |
15 | 16 | from torch_tensorrt._Device import Device
|
16 | 17 | from torch_tensorrt._enums import dtype
|
17 | 18 | from torch_tensorrt._features import ENABLED_FEATURES
|
@@ -256,48 +257,54 @@ def prepare_inputs(
|
256 | 257 | inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
|
257 | 258 | disable_memory_format_check: bool = False,
|
258 | 259 | ) -> Any:
|
259 |
| - if inputs is None: |
260 |
| - return None |
261 |
| - |
262 |
| - elif isinstance(inputs, Input): |
263 |
| - return inputs |
| 260 | + """ |
| 261 | + We take a nested group of torch.Tensors or scalars and convert them into torchtrt.Input's |
| 262 | + """ |
| 263 | + # Any tensors created inside this call will be FakeTensors if it's inside a torch.compile session |
| 264 | + # So, we disable fake mode temporarily. |
| 265 | + with unset_fake_temporarily(): |
| 266 | + if inputs is None: |
| 267 | + return None |
264 | 268 |
|
265 |
| - elif isinstance(inputs, (torch.Tensor, int, float, bool)): |
266 |
| - return Input.from_tensor( |
267 |
| - torch.tensor(inputs), |
268 |
| - disable_memory_format_check=disable_memory_format_check, |
269 |
| - ) |
| 269 | + elif isinstance(inputs, Input): |
| 270 | + return inputs |
270 | 271 |
|
271 |
| - elif isinstance(inputs, (list, tuple)): |
272 |
| - torchtrt_input_list = [] |
273 |
| - for input_obj in inputs: |
274 |
| - torchtrt_input = prepare_inputs( |
275 |
| - input_obj, disable_memory_format_check=disable_memory_format_check |
| 272 | + elif isinstance(inputs, (torch.Tensor, int, float, bool)): |
| 273 | + return Input.from_tensor( |
| 274 | + torch.tensor(inputs), |
| 275 | + disable_memory_format_check=disable_memory_format_check, |
276 | 276 | )
|
277 |
| - torchtrt_input_list.append(torchtrt_input) |
278 |
| - |
279 |
| - return ( |
280 |
| - torchtrt_input_list |
281 |
| - if isinstance(inputs, list) |
282 |
| - else tuple(torchtrt_input_list) |
283 |
| - ) |
284 | 277 |
|
285 |
| - elif isinstance(inputs, dict): |
286 |
| - torchtrt_inputs_dict: Dict[Any, Any] = dict() |
| 278 | + elif isinstance(inputs, (list, tuple)): |
| 279 | + torchtrt_input_list = [] |
| 280 | + for input_obj in inputs: |
| 281 | + torchtrt_input = prepare_inputs( |
| 282 | + input_obj, disable_memory_format_check=disable_memory_format_check |
| 283 | + ) |
| 284 | + torchtrt_input_list.append(torchtrt_input) |
287 | 285 |
|
288 |
| - for key, input_obj in inputs.items(): |
289 |
| - torchtrt_input = prepare_inputs( |
290 |
| - input_obj, disable_memory_format_check=disable_memory_format_check |
| 286 | + return ( |
| 287 | + torchtrt_input_list |
| 288 | + if isinstance(inputs, list) |
| 289 | + else tuple(torchtrt_input_list) |
291 | 290 | )
|
292 |
| - torchtrt_inputs_dict[key] = torchtrt_input |
293 | 291 |
|
294 |
| - return torchtrt_inputs_dict |
| 292 | + elif isinstance(inputs, dict): |
| 293 | + torchtrt_inputs_dict: Dict[Any, Any] = dict() |
295 | 294 |
|
296 |
| - else: |
297 |
| - raise ValueError( |
298 |
| - f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " |
299 |
| - + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" |
300 |
| - ) |
| 295 | + for key, input_obj in inputs.items(): |
| 296 | + torchtrt_input = prepare_inputs( |
| 297 | + input_obj, disable_memory_format_check=disable_memory_format_check |
| 298 | + ) |
| 299 | + torchtrt_inputs_dict[key] = torchtrt_input |
| 300 | + |
| 301 | + return torchtrt_inputs_dict |
| 302 | + |
| 303 | + else: |
| 304 | + raise ValueError( |
| 305 | + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " |
| 306 | + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" |
| 307 | + ) |
301 | 308 |
|
302 | 309 |
|
303 | 310 | def parse_complex_tensor_structs(
|
|
0 commit comments