From edda3d9cd261b94c75464eecf8917a1d2605502a Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 12 Jul 2023 10:33:53 -0700
Subject: [PATCH] examples: Add example usage script for module-acc

- Add detailed tutorial for excluding modules or functions from tracing
  in Dynamo and writing custom converters for those excluded modules
---
 docsrc/index.rst                            |   1 +
 examples/dynamo/README.rst                  |   1 +
 .../dynamo_module_level_acceleration.py     | 162 ++++++++++++++++++
 3 files changed, 164 insertions(+)
 create mode 100644 examples/dynamo/dynamo_module_level_acceleration.py

diff --git a/docsrc/index.rst b/docsrc/index.rst
index ace492c84f..b289b2b33b 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -73,6 +73,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/dynamo_compile_resnet_example
    tutorials/_rendered_examples/dynamo/dynamo_compile_transformers_example
    tutorials/_rendered_examples/dynamo/dynamo_compile_advanced_usage
+   tutorials/_rendered_examples/dynamo/dynamo_module_level_acceleration
 
 Python API Documenation
 ------------------------
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index d3b6f9ddcf..d422fac4a3 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`dynamo_compile_resnet`: Compiling a ResNet model using the Dynamo Compile Frontend for ``torch_tensorrt.compile``
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`dynamo_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
+* :ref:`dynamo_module_level_acceleration`: Accelerate a specific ``torch.nn.Module`` or function by excluding it from decomposition
diff --git a/examples/dynamo/dynamo_module_level_acceleration.py b/examples/dynamo/dynamo_module_level_acceleration.py
new file mode 100644
index 0000000000..442ff29aed
--- /dev/null
+++ b/examples/dynamo/dynamo_module_level_acceleration.py
@@ -0,0 +1,162 @@
"""
.. _dynamo_module_level_acceleration:

Dynamo Module Level Acceleration Tutorial
=========================================

This interactive script is an overview of how module-level acceleration works in
`torch_tensorrt.dynamo.compile`, and of how it can be used to accelerate built-in or
custom `torch.nn` modules by excluding them from AOT tracing. The walkthrough below uses
`torch.nn.MaxPool1d` as its running example.
"""

# %%
# 1. The Placeholder
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be written in torch JIT syntax, indicating
# input and output types. The namespace, such as `tensorrt`, will cause the op to be
# registered as `torch.ops.tensorrt.your_op`. Then, create a placeholder function with
# no operations, but with the same schema and naming as that used in the decorator.

# %%

from torch._custom_op.impl import custom_op


@custom_op(
    qualname="tensorrt::maxpool1d",
    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
)
def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
    # Defines operator schema, name, namespace, and function header
    ...


# %%
# 2. The Generic Implementation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define the default implementation of the operator in torch syntax. This is used for autograd
# and other tracing functionality. Generally, the `torch.nn.functional` analog of the operator
# being replaced is the best choice. If the operator being replaced is a custom module you have
# written, add its Torch implementation here. Note that the function header of the generic
# implementation can use specific named arguments, as in the placeholder above.

# %%
import torch


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
@maxpool1d.impl_abstract()
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )
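# %%
# As an optional eager-mode sanity check, the new operator is now callable as
# `torch.ops.tensorrt.maxpool1d`. The snippet below is a sketch (it assumes the
# registrations above succeeded on your PyTorch build) verifying that the generic
# implementation agrees with the built-in module it will eventually replace

x = torch.randn(2, 4, 16)
out_custom = torch.ops.tensorrt.maxpool1d(x, [3], [1], [0], [1], False)
out_builtin = torch.nn.MaxPool1d(
    kernel_size=3, stride=1, padding=0, dilation=1, ceil_mode=False
)(x)
# The custom op dispatches to maxpool1d_generic, so the results should match exactly
assert torch.allclose(out_custom, out_builtin)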
# %%
# 3. The Module Substitution Function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be defined in step 4). This function
# should use the operator defined in step 1 (for example `torch.ops.tensorrt.maxpool1d`).
# It should refactor the args and kwargs as needed by the accelerated implementation.

# %%

from torch_tensorrt.dynamo.backend.lowering import register_substitution


@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    submodule: torch.nn.Module,
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node


# %%
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write::
#
#     weights = gm.graph.get_attr(node.target + ".weight", torch.Tensor)
#     bias = gm.graph.get_attr(node.target + ".bias", torch.Tensor)
#
#     ...
#
#     kwargs={"weight": weights,
#             "bias": bias,
#             ...
#            }
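# %%
# Putting that pattern together, a substitution function for a weighted module such as
# `torch.nn.Conv1d` might look like the sketch below. Note that `torch.ops.tensorrt.conv1d`
# is purely illustrative here - it is not defined in this file, and would first need its
# own placeholder, generic implementation, and converter per steps 1, 2, and 4::
#
#     @register_substitution(torch.nn.Conv1d, torch.ops.tensorrt.conv1d)
#     def conv1d_insertion_fn(gm, node, submodule):
#         # Insert graph nodes which fetch the submodule's parameters at runtime
#         weights = gm.graph.get_attr(node.target + ".weight", torch.Tensor)
#         bias = gm.graph.get_attr(node.target + ".bias", torch.Tensor)
#
#         return gm.graph.call_function(
#             torch.ops.tensorrt.conv1d,  # hypothetical op (see steps 1 and 2)
#             args=node.args,
#             kwargs={
#                 "weight": weights,
#                 "bias": bias,
#                 "stride": submodule.stride,
#                 "padding": submodule.padding,
#             },
#         )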
+ +# %% + +from typing import Dict, Tuple +from torch.fx.node import Argument, Target +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.converter_registry import tensorrt_converter +from torch_tensorrt.fx.converters import acc_ops_converters + + +@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default) +def tensorrt_maxpool1d( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> TRTTensor: + # Defines converter replacing the default operator for this function + kwargs_new = { + "input": args[0], + "kernel_size": args[1], + "stride": args[2], + "padding": args[3], + "dilation": args[4], + "ceil_mode": False if len(args) < 6 else args[5], + } + + return acc_ops_converters.acc_ops_max_pool1d( + network, target, None, kwargs_new, name + ) + + +# %% +# 5. Add Imports +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Add your accelerated module file to the `__init__.py` in the +# `py/torch_tensorrt/dynamo/backend/lowering/substitutions` directory, to ensure +# all registrations are run. For instance, if the new module file is called `new_mod.py`, +# one should add `from .new_mod import *` to the `__init__.py`