@@ -51,6 +51,7 @@
     "ModuleCallSignature",
     "dims",
     "export",
+    "export_for_training",
     "load",
     "register_dataclass",
     "save",
@@ -69,6 +70,91 @@
 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
 
 
+def export_for_training(
+    mod: torch.nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    strict: bool = True,
+    preserve_module_call_signature: Tuple[str, ...] = (),
+) -> ExportedProgram:
+    """
+    :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
+    only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
+    which can subsequently be executed with different inputs or serialized. The
+    traced graph (1) produces normalized operators in the full ATen operator set
+    (as well as any user-specified custom operators), (2) has eliminated all Python control
+    flow and data structures (with certain exceptions), and (3) records the set of
+    shape constraints needed to show that this normalization and control-flow elimination
+    is sound for future inputs. This API is intended for PT2 quantization training use cases
+    and will become the default IR of torch.export.export in the near future.
+
+    **Soundness Guarantee**
+
+    See the :func:`export()` docstring for more details.
+
+    Args:
+        mod: We will trace the forward method of this module.
+
+        args: Example positional inputs.
+
+        kwargs: Optional example keyword inputs.
+
+        dynamic_shapes:
+            An optional argument whose type should be either:
+            1) a dict from argument names of ``f`` to their dynamic shape specifications,
+            2) a tuple that specifies dynamic shape specifications for each input in original order.
+            If you are specifying dynamism on keyword args, you will need to pass them in the order that
+            is defined in the original function signature.
+
+            The dynamic shape of a tensor argument can be specified as either
+            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
+            not required to include static dimension indices in this dict, but when they are,
+            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
+            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
+            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
+            recursively specified by using mappings or sequences of contained specifications.
+
+        strict: When enabled (default), the export function will trace the program through
+            TorchDynamo, which will ensure the soundness of the resulting graph. Otherwise, the
+            exported program will not validate the implicit assumptions baked into the graph and
+            may cause behavior divergence between the original model and the exported one. This is
+            useful when users need to work around bugs in the tracer, or simply want to incrementally
+            enable safety in their models. Note that this does not change the resulting IR spec,
+            and the model will be serialized in the same way regardless of what value
+            is passed here.
+            WARNING: This option is experimental; use it at your own risk.
+
+    Returns:
+        An :class:`ExportedProgram` containing the traced callable.
+
+    **Acceptable input/output types**
+
+    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
+
+    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
+    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
+    - (Nested) data structures comprising ``dict``, ``list``, ``tuple``, ``namedtuple`` and
+      ``OrderedDict`` containing all of the above types.
+
+    """
+    from ._trace import _export_for_training
+
+    if not isinstance(mod, torch.nn.Module):
+        raise ValueError(
+            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
+        )
+    return _export_for_training(
+        mod,
+        args,
+        kwargs,
+        dynamic_shapes,
+        strict=strict,
+        preserve_module_call_signature=preserve_module_call_signature,
+    )
+
+
 def export(
     mod: torch.nn.Module,
     args: Tuple[Any, ...],
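For context, here is a minimal usage sketch of the API added in this diff. It is not part of the commit: the toy module, the tensor shapes, and the tuple-form `dynamic_shapes` spec are illustrative assumptions; only `torch.export.export_for_training` and `torch.export.Dim` come from the code above.

```python
import torch
from torch.export import Dim, export_for_training


# Toy module used purely for illustration; not part of this commit.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


example_args = (torch.randn(4, 16),)

# Tuple-form dynamic_shapes: one spec per positional input.
# Dim 0 (batch) is dynamic; dim 1 is static, so it maps to None.
batch = Dim("batch")
ep = export_for_training(
    M(),
    example_args,
    dynamic_shapes=({0: batch, 1: None},),
)

print(ep)  # ExportedProgram containing the traced training IR
```

Per the docstring, the same spec could instead be passed as a dict keyed by argument name (e.g. `dynamic_shapes={"x": {0: batch}}`), with static dimensions omitted.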