
Commit 6dd3f81

tugsbayasgalan authored and pytorchmergebot committed

Add export_for_training as public API (pytorch#134677)

Differential Revision: [D61912084](https://our.internmc.facebook.com/intern/diff/D61912084)
Pull Request resolved: pytorch#134677
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
1 parent a7933ac commit 6dd3f81
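
As a quick orientation before the per-file diffs, here is a minimal usage sketch of the new public entry point, assuming a PyTorch build that contains this commit; the module and input shapes are illustrative, not taken from the PR:

```python
import torch

class TinyModel(torch.nn.Module):  # illustrative module, not from the PR
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x).relu()

# Previously callers reached into the private helper
# torch.export._trace._export_for_training; with this commit the same
# functionality is exposed as torch.export.export_for_training.
ep = torch.export.export_for_training(TinyModel(), (torch.randn(2, 4),))
print(ep)  # an ExportedProgram carrying the training IR
```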

File tree

4 files changed (+102 −15)

test/export/test_export.py

+6-6
@@ -1437,7 +1437,7 @@ def forward(self, x, y):
                 x_linear = self.linear(x_conv)
                 return x_linear.cos() + y_conv_1d.sum()

-        ep = torch.export._trace._export_for_training(
+        ep = torch.export.export_for_training(
             Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
         )
         ep_has_linear_convd = ep.run_decompositions(
@@ -1570,7 +1570,7 @@ def forward(self, x):
                 return self.linear(x)

         eager_model = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model, (torch.ones(2, 2),)
         )
         self.assertExpectedInline(
@@ -1609,7 +1609,7 @@ def forward(self, x):

         eager_model_for_export = Foo()
         eager_model_for_testing = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model_for_export, (torch.ones(4, 4),)
         )
         self.assertExpectedInline(
@@ -1654,7 +1654,7 @@ def forward(self, x):
         eager_model_for_export_training = Foo()
         eager_model_for_export_inference = Foo()
         eager_model_for_testing = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model_for_export_training,
             (torch.ones(4, 4),),
             dynamic_shapes=({0: Dim("x")},),
@@ -1691,7 +1691,7 @@ def forward(self, container):
                 return x + y + self.buffer.sum()

         eager_model = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model,
             ([torch.ones(4, 4), torch.ones(4, 4)],),
         )
@@ -1717,7 +1717,7 @@ def forward(self, x):
                 return self.linear(x) + self.buffer.sum()

         eager_model = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model,
             (torch.ones(2, 2),),
         )
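
The tests above all follow the same pattern: export a module for training, then lower to the inference IR with run_decompositions. A hedged sketch of that flow, using a hypothetical conv+linear module standing in for the tests' Foo:

```python
import torch

class ConvLinear(torch.nn.Module):  # hypothetical stand-in for the tests' Foo
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(16, 33, 3)
        self.linear = torch.nn.Linear(98, 20)

    def forward(self, x):
        return self.linear(self.conv(x))

ep = torch.export.export_for_training(ConvLinear(), (torch.randn(20, 16, 50, 100),))
# The training IR keeps ops such as aten.conv2d and aten.linear un-decomposed;
# run_decompositions({}) lowers the graph to the core ATen inference IR,
# which is what the tests compare against.
ep_inference = ep.run_decompositions({})
```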

test/export/test_export_training_ir_to_run_decomp.py

+6-5
@@ -1,29 +1,30 @@
 # Owner(s): ["oncall: export"]
+import torch
+

 try:
     from . import test_export, testing
 except ImportError:
     import test_export
-    import testing

-from torch.export._trace import _export_for_training
+    import testing


 test_classes = {}


 def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
-    ep = _export_for_training(*args, **kwargs)
+    ep = torch.export.export_for_training(*args, **kwargs)
     return ep.run_decompositions(
         {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
     )


 def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
     if "strict" in kwargs:
-        ep = _export_for_training(*args, **kwargs)
+        ep = torch.export.export_for_training(*args, **kwargs)
     else:
-        ep = _export_for_training(*args, **kwargs, strict=False)
+        ep = torch.export.export_for_training(*args, **kwargs, strict=False)
     return ep.run_decompositions(
         {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
     )
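
The non-strict wrapper above simply forwards `strict=False` to the same public API. A hedged sketch of calling it both ways, with a toy module of my own rather than anything from the test suite:

```python
import torch

class Toy(torch.nn.Module):  # toy module for illustration only
    def forward(self, x):
        return x.sin() + 1

example = (torch.randn(3),)

# Strict mode (the default): the program is traced through TorchDynamo,
# which validates the assumptions baked into the graph.
ep_strict = torch.export.export_for_training(Toy(), example)

# Non-strict mode: skips that validation, matching the mocked non-strict wrapper above.
ep_nonstrict = torch.export.export_for_training(Toy(), example, strict=False)
```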

torch/_export/__init__.py

+4-4
@@ -65,9 +65,9 @@ def capture_pre_autograd_graph_warning():
     log.warning("| !!! WARNING !!! |")
     log.warning("+============================+")
     log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
-    log.warning("Please switch to use torch.export._trace._export_for_training instead.")
+    log.warning("Please switch to use torch.export.export_for_training instead.")
     if config.is_fbcode():
-        log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export._trace._export_for_training.")  # noqa: B950
+        log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.")  # noqa: B950


 @compatibility(is_backward_compatible=False)
@@ -128,9 +128,9 @@ def capture_pre_autograd_graph(
     if capture_pre_autograd_graph_using_training_ir():
         @lru_cache
         def print_export_warning():
-            log.warning("Using torch.export._trace._export_for_training(...,strict=True)")
+            log.warning("Using torch.export.export_for_training(...,strict=True)")
         print_export_warning()
-        module = torch.export._trace._export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
+        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
     else:
         log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
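
The updated deprecation message points users at the public API. A hedged sketch of the migration the warning suggests (the model and inputs here are illustrative, not from the codebase):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = (torch.randn(1, 8),)

# Old, deprecated path (emits the warning patched above):
#   gm = torch._export.capture_pre_autograd_graph(model, example_inputs)
#
# Suggested replacement: export for training and take the resulting fx module,
# mirroring the fallback inside capture_pre_autograd_graph itself.
gm = torch.export.export_for_training(model, example_inputs, strict=True).module()
```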

torch/export/__init__.py

+86
@@ -51,6 +51,7 @@
     "ModuleCallSignature",
     "dims",
     "export",
+    "export_for_training",
     "load",
     "register_dataclass",
     "save",
@@ -69,6 +70,91 @@
 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


+def export_for_training(
+    mod: torch.nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    strict: bool = True,
+    preserve_module_call_signature: Tuple[str, ...] = (),
+) -> ExportedProgram:
+    """
+    :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
+    only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
+    which can subsequently be executed with different inputs or serialized. The
+    traced graph (1) produces normalized operators in the ATen operator set
+    (as well as any user-specified custom operators), (2) has eliminated all Python control
+    flow and data structures (with certain exceptions), and (3) records the set of
+    shape constraints needed to show that this normalization and control-flow elimination
+    is sound for future inputs. This API is intended for PT2 quantization training use cases
+    and will become the default IR of torch.export.export in the near future.
+
+    **Soundness Guarantee**
+
+    See the :func:`export()` docstring for more details.
+
+    Args:
+        mod: We will trace the forward method of this module.
+
+        args: Example positional inputs.
+
+        kwargs: Optional example keyword inputs.
+
+        dynamic_shapes:
+         An optional argument where the type should either be:
+         1) a dict from argument names of ``f`` to their dynamic shape specifications,
+         2) a tuple that specifies dynamic shape specifications for each input in original order.
+         If you are specifying dynamism on keyword args, you will need to pass them in the order that
+         is defined in the original function signature.
+
+         The dynamic shape of a tensor argument can be specified as either
+         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
+         not required to include static dimension indices in this dict, but when they are,
+         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
+         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
+         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
+         recursively specified by using mappings or sequences of contained specifications.
+
+        strict: When enabled (default), the export function will trace the program through
+         TorchDynamo, which will ensure the soundness of the resulting graph. Otherwise, the
+         exported program will not validate the implicit assumptions baked into the graph and
+         may cause behavior divergence between the original model and the exported one. This is
+         useful when users need to work around bugs in the tracer, or simply want to incrementally
+         enable safety in their models. Note that this does not change the resulting IR spec,
+         and the model will be serialized in the same way regardless of what value is passed here.
+         WARNING: This option is experimental; use it at your own risk.
+
+    Returns:
+        An :class:`ExportedProgram` containing the traced callable.
+
+    **Acceptable input/output types**
+
+    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
+
+    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
+    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
+    - (Nested) Data structures comprising ``dict``, ``list``, ``tuple``, ``namedtuple`` and
+      ``OrderedDict`` containing all of the above types.
+
+    """
+    from ._trace import _export_for_training
+
+    if not isinstance(mod, torch.nn.Module):
+        raise ValueError(
+            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
+        )
+    return _export_for_training(
+        mod,
+        args,
+        kwargs,
+        dynamic_shapes,
+        strict=strict,
+        preserve_module_call_signature=preserve_module_call_signature,
+    )
+
+
 def export(
     mod: torch.nn.Module,
     args: Tuple[Any, ...],
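
To illustrate the signature documented above, here is a hedged sketch that exercises the `dynamic_shapes` argument with a symbolic batch dimension; the module, argument name, and dimension name are made up for the example:

```python
import torch
from torch.export import Dim, export_for_training

class MLP(torch.nn.Module):  # made-up module for illustration
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

batch = Dim("batch")  # symbolic batch dimension
ep = export_for_training(
    MLP(),
    (torch.randn(32, 16),),
    dynamic_shapes={"x": {0: batch}},  # dim 0 of argument `x` is dynamic
)
```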
