|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import functools |
3 | 4 | import logging
|
4 | 5 | import unittest
|
5 | 6 | from typing import Any, Callable, Sequence
|
6 | 7 |
|
7 | 8 | import torch
|
8 | 9 | import torch._dynamo as td
|
| 10 | +from torch._dynamo.backends.common import aot_autograd |
9 | 11 | from torch._dynamo.utils import detect_fake_mode
|
10 | 12 | from torch._functorch.aot_autograd import aot_export_joint_simple
|
11 | 13 | from torch_tensorrt.dynamo import CompilationSettings
|
12 | 14 | from torch_tensorrt.dynamo._compiler import compile_module
|
13 | 15 | from torch_tensorrt.dynamo.lowering import (
|
14 | 16 | get_decompositions,
|
| 17 | + modify_reshape_complex_nodes, |
15 | 18 | post_lowering,
|
16 | 19 | remove_detach,
|
17 | 20 | remove_sym_nodes,
|
def aot_torch_tensorrt_aten_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
    """Dynamo backend entry point for Torch-TensorRT.

    When ``settings.use_aot_joint_export`` is set, the graph module is handed
    straight to ``_pretraced_backend``. Otherwise the backend is wrapped with
    ``aot_autograd`` so that AOTAutograd performs the aten tracing, with
    ``_pretraced_backend`` acting as the forward compiler.

    Args:
        gm: The FX graph module captured by Dynamo.
        sample_inputs: Example inputs used for tracing.
        **kwargs: Dynamo backend kwargs; parsed into compilation settings
            and an optional engine cache.

    Returns:
        A compiled ``torch.nn.Module``.
    """
    settings, engine_cache = parse_dynamo_kwargs(kwargs)

    if settings.use_aot_joint_export:
        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)

    logger.debug("Wrapping the backend with aot_autograd\n")
    _pretraced_backend_autograd = functools.partial(
        _pretraced_backend, settings=settings, engine_cache=engine_cache
    )

    # Build the decomposition table once and reuse it below.
    # Fixes two defects in the previous version: the dict key was written as
    # "decompostions" (typo) but read back as "decompositions" (KeyError on
    # first use of this path), and the filtered table was then discarded —
    # aot_autograd was handed a fresh, unfiltered get_decompositions() result.
    decompositions = get_decompositions(settings.enable_experimental_decompositions)

    # Drop the detach decomposition: detach lowering leads to alias nodes
    # Error - View operation returned a tensor that is the same as the input base tensor
    # torch nop_decompositions in torch/_decomp/decompositions.py
    # pop(..., None) makes the removal a no-op when the key is absent.
    decompositions.pop(aten.detach, None)

    return aot_autograd(
        fw_compiler=_pretraced_backend_autograd,
        decompositions=decompositions,
    )(gm, sample_inputs)
53 | 74 |
|
54 | 75 |
|
55 | 76 | def _pretraced_backend(
|
@@ -89,22 +110,39 @@ def _pretraced_backend(
|
89 | 110 | # Remove detach nodes
|
90 | 111 | remove_detach(gm, settings)
|
91 | 112 |
|
| 113 | + complexInputIndices = [] |
| 114 | + for i, torch_input in enumerate(torch_inputs): |
| 115 | + if torch_inputs[i].dtype == torch.complex64: |
| 116 | + complexInputIndices.append(i) |
| 117 | + torch_input_real = torch_inputs[i].real |
| 118 | + torch_input_imaginary = torch_inputs[i].imag |
| 119 | + torch_inputs[i] = torch.stack( |
| 120 | + (torch_input_real, torch_input_imaginary), dim=-1 |
| 121 | + ) |
| 122 | + |
92 | 123 | # Invoke AOTAutograd to translate operators to aten
|
93 |
| - gm = aot_export_joint_simple( |
94 |
| - gm, |
95 |
| - sample_inputs, |
96 |
| - trace_joint=False, |
97 |
| - decompositions=get_decompositions( |
98 |
| - settings.enable_experimental_decompositions |
99 |
| - ), |
100 |
| - ) |
| 124 | + if settings.use_aot_joint_export: |
| 125 | + gm = aot_export_joint_simple( |
| 126 | + gm, |
| 127 | + sample_inputs, |
| 128 | + trace_joint=False, |
| 129 | + decompositions=get_decompositions( |
| 130 | + settings.enable_experimental_decompositions |
| 131 | + ), |
| 132 | + ) |
101 | 133 |
|
102 | 134 | logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
|
103 | 135 |
|
104 | 136 | gm = post_lowering(gm, settings)
|
105 | 137 |
|
106 | 138 | logger.debug("Lowered Input graph:\n " + str(gm.graph))
|
107 | 139 |
|
| 140 | + if complexInputIndices: |
| 141 | + modify_reshape_complex_nodes(gm, complexInputIndices) |
| 142 | + logger.debug( |
| 143 | + "Input graph after modifying complex nodes:\n " + str(gm.graph) |
| 144 | + ) |
| 145 | + |
108 | 146 | torchtrt_inputs = prepare_inputs(
|
109 | 147 | torch_inputs, disable_memory_format_check=True
|
110 | 148 | )
|
|
0 commit comments