examples: Add example usage script for module-acc
- Add detailed tutorial for excluding modules or functions from tracing
in Dynamo and writing custom converters for those excluded modules
gs-olive committed Jul 12, 2023
1 parent 2684dd9 commit edda3d9
Showing 3 changed files with 164 additions and 0 deletions.
1 change: 1 addition & 0 deletions docsrc/index.rst
@@ -73,6 +73,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/dynamo_compile_resnet_example
tutorials/_rendered_examples/dynamo/dynamo_compile_transformers_example
tutorials/_rendered_examples/dynamo/dynamo_compile_advanced_usage
tutorials/_rendered_examples/dynamo/dynamo_module_level_acceleration

Python API Documentation
------------------------
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`dynamo_compile_resnet`: Compiling a ResNet model using the Dynamo Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`dynamo_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`dynamo_module_level_acceleration`: Accelerate a specific ``torch.nn.Module`` or function by excluding it from decomposition
162 changes: 162 additions & 0 deletions examples/dynamo/dynamo_module_level_acceleration.py
@@ -0,0 +1,162 @@
"""
.. _dynamo_module_level_acceleration:
Dynamo Module Level Acceleration Tutorial
=========================
This interactive script is intended as an overview of the process by which module-level acceleration for `torch_tensorrt.dynamo.compile` works, and how it can be used to accelerate built-in or custom `torch.nn` modules by excluding them from AOT tracing. This script shows the process for `torch.nn.MaxPool1d`"""

# %%
# 1. The Placeholder
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be written in Torch JIT syntax, indicating
# input and output types. The namespace, such as `tensorrt`, causes the op to be
# registered as `torch.ops.tensorrt.your_op`. Then, create a placeholder function with
# no body, having the same name and arguments as the schema specified in the decorator.

# %%

from torch._custom_op.impl import custom_op


@custom_op(
    qualname="tensorrt::maxpool1d",
    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
)
def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
    # Defines operator schema, name, namespace, and function header
    ...


# %%
# 2. The Generic Implementation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define the default implementation of the operator in Torch syntax. This is used for
# autograd and other tracing functionality. Generally, the `torch.nn.functional` analog
# of the operator being replaced is a good choice. If the operator being replaced is a
# custom module you have written, add its Torch implementation here. Note that the
# function header of the generic implementation can have specific arguments, as in the
# placeholder above.

# %%
import torch


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
@maxpool1d.impl_abstract()
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )
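

# %%
# As a quick sanity check of the two registrations above, the custom operator should
# now agree with its `torch.nn.functional` analog on CPU. This is a minimal sketch;
# the call syntax follows the schema specified in step 1.

# %%

sample = torch.rand(1, 3, 16)
custom_out = torch.ops.tensorrt.maxpool1d(sample, [3], [1], [0], [1], False)
reference_out = torch.nn.functional.max_pool1d(
    sample, kernel_size=3, stride=1, padding=0, dilation=1, ceil_mode=False
)
assert torch.allclose(custom_out, reference_out)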


# %%
# 3. The Module Substitution Function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be implemented in step 4). This function
# should use the operator defined in step 1 (for example `torch.ops.tensorrt.maxpool1d`).
# It should refactor the args and kwargs as needed by the accelerated implementation.

# %%

from torch_tensorrt.dynamo.backend.lowering import register_substitution


@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    submodule: torch.nn.Module,
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node


# %%
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write::
#
#     weights = gm.graph.get_attr(node.target + ".weight", torch.Tensor)
#     bias = gm.graph.get_attr(node.target + ".bias", torch.Tensor)
#
#     ...
#
#     kwargs={"weight": weights,
#             "bias": bias,
#             ...
#     }
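
# %%
# As a fuller illustration, here is a minimal sketch of a substitution function for a
# hypothetical `torch.nn.Linear` replacement, assuming a corresponding
# `torch.ops.tensorrt.linear` operator and converter were registered as in steps 1-4::
#
#     @register_substitution(torch.nn.Linear, torch.ops.tensorrt.linear)
#     def linear_insertion_fn(
#         gm: torch.fx.GraphModule,
#         node: torch.fx.Node,
#         submodule: torch.nn.Module,
#     ) -> torch.fx.Node:
#         weights = gm.graph.get_attr(node.target + ".weight", torch.Tensor)
#         bias = gm.graph.get_attr(node.target + ".bias", torch.Tensor)
#
#         return gm.graph.call_function(
#             torch.ops.tensorrt.linear,
#             args=node.args,
#             kwargs={"weight": weights, "bias": bias},
#         )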

# %%
# 4. The Accelerated Implementation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define an accelerated implementation of the operator, and register it as necessary.
# This accelerated implementation should consume the args/kwargs specified in step 3.
# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.

# %%

from typing import Dict, Tuple
from torch.fx.node import Argument, Target
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def tensorrt_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines converter replacing the default operator for this function
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
        "stride": args[2],
        "padding": args[3],
        "dilation": args[4],
        "ceil_mode": False if len(args) < 6 else args[5],
    }

    return acc_ops_converters.acc_ops_max_pool1d(
        network, target, None, kwargs_new, name
    )


# %%
# 5. Add Imports
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Import your accelerated module file in the `__init__.py` of the
# `py/torch_tensorrt/dynamo/backend/lowering/substitutions` directory, to ensure
# all registrations are run. For instance, if the new module file is called `new_mod.py`,
# add `from .new_mod import *` to the `__init__.py`, as sketched below.
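#
# A minimal sketch of the resulting `__init__.py`, assuming the hypothetical module
# file `new_mod.py` and that this tutorial's substitution lives in `maxpool1d.py`::
#
#     from .maxpool1d import *
#     from .new_mod import *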

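# %%
# With the above registrations in place and imported, a model containing
# `torch.nn.MaxPool1d` can then be compiled as usual. The following is a minimal usage
# sketch, assuming a CUDA-capable machine with TensorRT installed; the exact keyword
# arguments accepted by `torch_tensorrt.dynamo.compile` may differ across versions::
#
#     import torch
#     import torch_tensorrt
#
#     model = torch.nn.Sequential(torch.nn.MaxPool1d(kernel_size=3)).eval().cuda()
#     inputs = [torch.rand(1, 3, 16).cuda()]
#
#     optimized_model = torch_tensorrt.dynamo.compile(model, inputs, min_block_size=1)
#     print(optimized_model(*inputs))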