-
Notifications
You must be signed in to change notification settings - Fork 356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: Initial implementation for automatic plugin #3301
Open
bowang007
wants to merge
8
commits into
main
Choose a base branch
from
auto_plugin
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,047
−354
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
github-actions
bot
added
component: conversion
Issues re: Conversion stage
component: build system
Issues re: Build system
component: api [Python]
Issues re: Python API
component: dynamo
Issues relating to the `torch.compile` or `torch._dynamo.export` paths
labels
Nov 22, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py 2024-11-22 01:20:58.215888+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py 2024-11-22 01:21:18.909129+00:00
@@ -1,7 +1,8 @@
import triton
import triton.language as tl
+
@triton.jit
def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
# Program ID determines the block of data each thread will process
pid = tl.program_id(0)
@@ -25,23 +26,23 @@
@custom_op("torchtrt_ex::elementwise_add", mutates_args=()) # type: ignore[misc]
def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."
-
+
# Create output tensor
Z = torch.empty_like(X)
-
+
# Define block size
BLOCK_SIZE = 1024
-
+
# Grid of programs
- grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],)
-
+ grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
+
# Launch the kernel
elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)
-
+
return Z
# Using the module in PyTorch
# X = torch.randn(1024, device='cuda', requires_grad=True)
@@ -72,22 +73,31 @@
return res
my_model = MyModel().to("cuda")
-m = torch.full((64, 64), 2, device='cuda',)
-n = torch.full((64, 64), 3, device='cuda',)
+m = torch.full(
+ (64, 64),
+ 2,
+ device="cuda",
+)
+n = torch.full(
+ (64, 64),
+ 3,
+ device="cuda",
+)
# print(torch.ops.torchtrt_ex.elementwise_add(m, n))
# print(my_model.forward(m, n))
@torch.library.register_fake("torchtrt_ex::elementwise_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x
+
import torch_tensorrt as torchtrt
with torchtrt.logging.info():
model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
res = model_trt(m, n)
- print(res)
\ No newline at end of file
+ print(res)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py 2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py 2024-11-22 01:21:19.453080+00:00
@@ -1,6 +1,11 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters
+from . import (
+ aten_ops_converters,
+ ops_evaluators,
+ prims_ops_converters,
+ plugin_ops_converters,
+)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
from ._TRTInterpreter import * # noqa: F403
from .truncate_double import repair_double_inputs
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py 2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py 2024-11-22 01:21:20.202267+00:00
@@ -1 +1 @@
-from .plugin_generator import PluginCreator
\ No newline at end of file
+from .plugin_generator import PluginCreator
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py 2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py 2024-11-22 01:21:20.284627+00:00
@@ -17,25 +17,28 @@
logger = logging.getLogger(__name__)
TRT_PLUGIN_REGISTRY = trt.get_plugin_registry()
+
@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default)
def torchtrt_ex_elementwise_add(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
-):
+):
# logger.debug(f"plugin stuff here2")
# return torch.add(args)
-
+
# How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
- plugin_creator = PluginCreator("elementwise_add_plugin", plugin_namespace="", attrs={})
- TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
-
+ plugin_creator = PluginCreator(
+ "elementwise_add_plugin", plugin_namespace="", attrs={}
+ )
+ TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
+
plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator(
type="elementwise_add_plugin", version="1", plugin_namespace=""
)
assert plugin_creator, f"Unable to find elementwise_add_plugin creator"
@@ -44,45 +47,47 @@
# plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs)
# assert plugin, "Unable to create <PLUGIN_NAME>"
# <GENERATE LINK BETWEEN PLUGIN AND INPUTS>
# <GET INPUTS INTO LIST>
- # <PASS TO PLUGIN>
-
+ # <PASS TO PLUGIN>
+
# return layer.get_output(0)
field_configs = trt.PluginFieldCollection([])
-
- plugin = plugin_creator.create_plugin(name="elementwise_add_plugin", field_collection=field_configs)
+
+ plugin = plugin_creator.create_plugin(
+ name="elementwise_add_plugin", field_collection=field_configs
+ )
assert plugin, "Unable to create CircularPaddingPlugin"
-
+
# input_tensor = args[
# 0
# ] # Arg 0 `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor
# if not isinstance(input_tensor, trt.ITensor):
# # Freeze input tensor if not TensorRT Tensor already
# input_tensor = get_trt_tensor(ctx, input_tensor, f"{name}_input")
-
+
lhs_dtype = None
rhs_dtype = None
lhs_val = args[0]
rhs_val = args[1]
-
+
if isinstance(lhs_val, TRTTensor):
lhs_dtype = lhs_val.dtype
# is_lhs_trt_tensor = True
if isinstance(rhs_val, TRTTensor):
rhs_dtype = rhs_val.dtype
# is_rhs_trt_tensor = True
-
+
print(lhs_dtype)
-
+
lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
layer = ctx.net.add_plugin_v3(
[lhs_val, rhs_val], [], plugin
) # Add the plugin to the network being constructed
# layer.name = f"automatic-{name}"
return layer.get_output(0)
-# 1. generate plugin for any pytorch op
\ No newline at end of file
+# 1. generate plugin for any pytorch op
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py 2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py 2024-11-22 01:21:20.380983+00:00
@@ -11,64 +11,63 @@
logger = logging.getLogger("CustomPlugin")
_numpy_to_plugin_field_type = {
- np.dtype('int32'): trt.PluginFieldType.INT32,
- np.dtype('int16'): trt.PluginFieldType.INT16,
- np.dtype('int8'): trt.PluginFieldType.INT8,
- np.dtype('bool'): trt.PluginFieldType.INT8,
- np.dtype('int64'): trt.PluginFieldType.INT64,
- np.dtype('float32'): trt.PluginFieldType.FLOAT32,
- np.dtype('float64'): trt.PluginFieldType.FLOAT64,
- np.dtype('float16'): trt.PluginFieldType.FLOAT16
+ np.dtype("int32"): trt.PluginFieldType.INT32,
+ np.dtype("int16"): trt.PluginFieldType.INT16,
+ np.dtype("int8"): trt.PluginFieldType.INT8,
+ np.dtype("bool"): trt.PluginFieldType.INT8,
+ np.dtype("int64"): trt.PluginFieldType.INT64,
+ np.dtype("float32"): trt.PluginFieldType.FLOAT32,
+ np.dtype("float64"): trt.PluginFieldType.FLOAT64,
+ np.dtype("float16"): trt.PluginFieldType.FLOAT16,
}
_built_in_to_plugin_field_type = {
int: trt.PluginFieldType.INT64,
float: trt.PluginFieldType.FLOAT64,
bool: trt.PluginFieldType.INT8,
# str is handled separately, so not needed here
}
+
class Tactic(IntEnum):
TORCH = 1
TRITON = 2
+
class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): # type: ignore[misc]
- def __init__(
- self, plugin_name : str, attrs, phase = None
- ):
+ def __init__(self, plugin_name: str, attrs, phase=None):
# TODO: needs an additional passed in arguments to specify the needs for each plugin
# such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83
trt.IPluginV3.__init__(self)
# Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime
trt.IPluginV3OneCore.__init__(self)
# Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder.
trt.IPluginV3OneBuild.__init__(self)
# Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable
- trt.IPluginV3OneRuntime.__init__(self)
-
+ trt.IPluginV3OneRuntime.__init__(self)
+
# <ANY NON TENSOR INPUTS SHOULD BE AN ATTRIBUTE OF THE PLUGIN>
- # setattr(<name of input>, <default value for that type>)
+ # setattr(<name of input>, <default value for that type>)
# self.pads = []
# self.X_shape: List[int] = []
-
- self.num_outputs = 1 # Defined by schema
+
+ self.num_outputs = 1 # Defined by schema
self.plugin_namespace = ""
self.plugin_name = plugin_name
- self.plugin_version = "1"
+ self.plugin_version = "1"
# Set the timing cache ID to prevent unnecessary timing of second plugin instance
self.timing_cache_id = ""
self.attrs = attrs
-
+
self.tactic = None
-
-
- # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR>
+
+ # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR>
# ex.
# TODO: need to parse the field collection here
# if fc is not None:
# assert fc[0].name == "pads"
# self.pads = fc[0].data
@@ -77,14 +76,12 @@
self.phase = phase
def get_capability_interface(self, type):
return self
- def get_output_data_types(
- self, input_types: List[trt.DataType]
- ) -> trt.DataType:
- # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE
+ def get_output_data_types(self, input_types: List[trt.DataType]) -> trt.DataType:
+ # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE
# with torch.fake_tensor():
# <GENERATE FAKE INPUTS OF TYPE INPUT_TYPES>
# fake_outputs = torch.ops.<custom_ns>.<custom_op>(*fake_inputs)
# return fake_outputs[index]
@@ -96,20 +93,20 @@
self,
inputs: List[trt.DimsExprs],
shape_inputs,
exprBuilder: trt.IExprBuilder,
) -> trt.DimsExprs:
-
+
print(inputs)
- # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR
- # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE
- # SHAPE MAP.
+ # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR
+ # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE
+ # SHAPE MAP.
output_dims = trt.DimsExprs(inputs[0])
return [output_dims]
-
+
def get_fields_to_serialize(self):
# should be passed in as another argument
field_names = []
for key, value in self.attrs.items():
@@ -149,11 +146,11 @@
self.X_shape = np.zeros((len(X_dims),))
for i in range(len(X_dims)):
self.X_shape[i] = X_dims[i]
def supports_format_combination(self, pos, in_out, num_inputs):
- return
+ return
assert num_inputs == 1
assert pos < len(in_out)
desc = in_out[pos].desc
if desc.format != trt.TensorFormat.LINEAR:
@@ -166,11 +163,10 @@
# output should have the same type as the input
if pos == 1:
return in_out[0].desc.type == desc.type
assert False
-
def enqueue(
self,
input_desc: List[trt.PluginTensorDesc],
output_desc: List[trt.PluginTensorDesc],
@@ -180,40 +176,56 @@
stream: int,
) -> None:
# input and output memory handling
input_mems = [None] * (len(inputs))
- for i in range(len(inputs)):
- input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self)
+ for i in range(len(inputs)):
+ input_mems[i] = cp.cuda.UnownedMemory(
+ inputs[i],
+ np.prod(input_desc[i].dims)
+ * cp.dtype(trt.nptype(input_desc[i].type)).itemsize,
+ self,
+ )
output_mems = [None] * (len(outputs))
for i in range(len(outputs)):
- output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self)
-
+ output_mems[i] = cp.cuda.UnownedMemory(
+ outputs[i],
+ np.prod(output_desc[i].dims)
+ * cp.dtype(trt.nptype(output_desc[i].type)).itemsize,
+ self,
+ )
input_data = [None] * ((len(inputs)))
for i in range(len(inputs)):
- input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0))
+ input_data[i] = cp.ndarray(
+ tuple(input_desc[i].dims),
+ dtype=input_desc[i].type,
+ memptr=cp.cuda.MemoryPointer(input_mems[i], 0),
+ )
output_data = [None] * ((len(outputs)))
for i in range(len(outputs)):
- output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0))
-
- #TODO: This is just for a simple case for elementwise operations
+ output_data[i] = cp.ndarray(
+ (np.prod(output_desc[i].dims)),
+ dtype=output_desc[i].type,
+ memptr=cp.cuda.MemoryPointer(output_mems[i], 0),
+ )
+
+ # TODO: This is just for a simple case for elementwise operations
# using Torch implementation for now
- input_torch_0 = torch.as_tensor(input_data[0], device='cuda')
- input_torch_1 = torch.as_tensor(input_data[1], device='cuda')
+ input_torch_0 = torch.as_tensor(input_data[0], device="cuda")
+ input_torch_1 = torch.as_tensor(input_data[1], device="cuda")
output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1)
cp.copyto(output_data, output)
-
def attach_to_context(self, context):
return self.clone()
-
+
def get_valid_tactics(self):
return [int(Tactic.TORCH), int(Tactic.TRITON)]
def set_tactic(self, tactic):
self.tactic = Tactic(tactic)
@@ -226,17 +238,17 @@
cloned_plugin.__dict__.update(self.__dict__)
return cloned_plugin
class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc]
- def __init__(self, plugin_name : str, plugin_namespace : str, attrs):
- trt.IPluginCreatorV3One.__init__(self)
+ def __init__(self, plugin_name: str, plugin_namespace: str, attrs):
+ trt.IPluginCreatorV3One.__init__(self)
self.name = plugin_name
self.plugin_namespace = plugin_namespace
self.plugin_version = "1"
-
+
field_names = []
for name, (builtin, type_) in attrs.items():
if builtin:
if type_ is str:
field_names.append(
@@ -259,15 +271,12 @@
)
)
self.field_names = trt.PluginFieldCollection(field_names)
- def create_plugin(
- self, name: str, field_collection, phase=None
- ) -> CustomPlugin:
-
-
+ def create_plugin(self, name: str, field_collection, phase=None) -> CustomPlugin:
+
attrs = {}
# for f in fc:
# if f.name not in desc.input_attrs:
# raise AssertionError(
# f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
@@ -275,10 +284,9 @@
# if _is_numpy_array(desc.input_attrs[f.name]):
# attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name]))
# else:
# attrs[f.name] = desc.input_attrs[f.name](f.data)
-
+
custom_plugin = CustomPlugin(name, attrs)
-
+
return custom_plugin
-
narendasan
force-pushed
the
auto_plugin
branch
from
November 26, 2024 20:16
ecfb610
to
f0b0a0f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py 2024-11-26 20:16:28.712186+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py 2024-11-26 20:16:48.244419+00:00
@@ -1,7 +1,8 @@
import triton
import triton.language as tl
+
@triton.jit
def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
# Program ID determines the block of data each thread will process
pid = tl.program_id(0)
@@ -25,23 +26,23 @@
@custom_op("torchtrt_ex::elementwise_add", mutates_args=()) # type: ignore[misc]
def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."
-
+
# Create output tensor
Z = torch.empty_like(X)
-
+
# Define block size
BLOCK_SIZE = 1024
-
+
# Grid of programs
- grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],)
-
+ grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
+
# Launch the kernel
elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)
-
+
return Z
# Using the module in PyTorch
# X = torch.randn(1024, device='cuda', requires_grad=True)
@@ -72,22 +73,31 @@
return res
my_model = MyModel().to("cuda")
-m = torch.full((64, 64), 2, device='cuda',)
-n = torch.full((64, 64), 3, device='cuda',)
+m = torch.full(
+ (64, 64),
+ 2,
+ device="cuda",
+)
+n = torch.full(
+ (64, 64),
+ 3,
+ device="cuda",
+)
# print(torch.ops.torchtrt_ex.elementwise_add(m, n))
# print(my_model.forward(m, n))
@torch.library.register_fake("torchtrt_ex::elementwise_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x
+
import torch_tensorrt as torchtrt
with torchtrt.logging.info():
model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
res = model_trt(m, n)
- print(res)
\ No newline at end of file
+ print(res)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py 2024-11-26 20:16:28.728186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py 2024-11-26 20:16:48.833342+00:00
@@ -1,6 +1,11 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters
+from . import (
+ aten_ops_converters,
+ ops_evaluators,
+ prims_ops_converters,
+ plugin_ops_converters,
+)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
from ._TRTInterpreter import * # noqa: F403
from .truncate_double import repair_double_inputs
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py 2024-11-26 20:16:28.728186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py 2024-11-26 20:16:49.583518+00:00
@@ -1 +1 @@
-from .plugin_generator import PluginCreator
\ No newline at end of file
+from .plugin_generator import PluginCreator
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py 2024-11-26 20:16:28.732186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py 2024-11-26 20:16:49.650545+00:00
@@ -17,25 +17,28 @@
logger = logging.getLogger(__name__)
TRT_PLUGIN_REGISTRY = trt.get_plugin_registry()
+
@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default)
def torchtrt_ex_elementwise_add(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
-):
+):
# logger.debug(f"plugin stuff here2")
# return torch.add(args)
-
+
# How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
- plugin_creator = PluginCreator("elementwise_add_plugin", plugin_namespace="", attrs={})
- TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
-
+ plugin_creator = PluginCreator(
+ "elementwise_add_plugin", plugin_namespace="", attrs={}
+ )
+ TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
+
plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator(
type="elementwise_add_plugin", version="1", plugin_namespace=""
)
assert plugin_creator, f"Unable to find elementwise_add_plugin creator"
@@ -44,45 +47,47 @@
# plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs)
# assert plugin, "Unable to create <PLUGIN_NAME>"
# <GENERATE LINK BETWEEN PLUGIN AND INPUTS>
# <GET INPUTS INTO LIST>
- # <PASS TO PLUGIN>
-
+ # <PASS TO PLUGIN>
+
# return layer.get_output(0)
field_configs = trt.PluginFieldCollection([])
-
- plugin = plugin_creator.create_plugin(name="elementwise_add_plugin", field_collection=field_configs)
+
+ plugin = plugin_creator.create_plugin(
+ name="elementwise_add_plugin", field_collection=field_configs
+ )
assert plugin, "Unable to create CircularPaddingPlugin"
-
+
# input_tensor = args[
# 0
# ] # Arg 0 `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor
# if not isinstance(input_tensor, trt.ITensor):
# # Freeze input tensor if not TensorRT Tensor already
# input_tensor = get_trt_tensor(ctx, input_tensor, f"{name}_input")
-
+
lhs_dtype = None
rhs_dtype = None
lhs_val = args[0]
rhs_val = args[1]
-
+
if isinstance(lhs_val, TRTTensor):
lhs_dtype = lhs_val.dtype
# is_lhs_trt_tensor = True
if isinstance(rhs_val, TRTTensor):
rhs_dtype = rhs_val.dtype
# is_rhs_trt_tensor = True
-
+
print(lhs_dtype)
-
+
lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)
layer = ctx.net.add_plugin_v3(
[lhs_val, rhs_val], [], plugin
) # Add the plugin to the network being constructed
# layer.name = f"automatic-{name}"
return layer.get_output(0)
-# 1. generate plugin for any pytorch op
\ No newline at end of file
+# 1. generate plugin for any pytorch op
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py 2024-11-26 20:16:28.732186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py 2024-11-26 20:16:49.769861+00:00
@@ -11,64 +11,63 @@
logger = logging.getLogger("CustomPlugin")
_numpy_to_plugin_field_type = {
- np.dtype('int32'): trt.PluginFieldType.INT32,
- np.dtype('int16'): trt.PluginFieldType.INT16,
- np.dtype('int8'): trt.PluginFieldType.INT8,
- np.dtype('bool'): trt.PluginFieldType.INT8,
- np.dtype('int64'): trt.PluginFieldType.INT64,
- np.dtype('float32'): trt.PluginFieldType.FLOAT32,
- np.dtype('float64'): trt.PluginFieldType.FLOAT64,
- np.dtype('float16'): trt.PluginFieldType.FLOAT16
+ np.dtype("int32"): trt.PluginFieldType.INT32,
+ np.dtype("int16"): trt.PluginFieldType.INT16,
+ np.dtype("int8"): trt.PluginFieldType.INT8,
+ np.dtype("bool"): trt.PluginFieldType.INT8,
+ np.dtype("int64"): trt.PluginFieldType.INT64,
+ np.dtype("float32"): trt.PluginFieldType.FLOAT32,
+ np.dtype("float64"): trt.PluginFieldType.FLOAT64,
+ np.dtype("float16"): trt.PluginFieldType.FLOAT16,
}
_built_in_to_plugin_field_type = {
int: trt.PluginFieldType.INT64,
float: trt.PluginFieldType.FLOAT64,
bool: trt.PluginFieldType.INT8,
# str is handled separately, so not needed here
}
+
class Tactic(IntEnum):
TORCH = 1
TRITON = 2
+
class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): # type: ignore[misc]
- def __init__(
- self, plugin_name : str, attrs, phase = None
- ):
+ def __init__(self, plugin_name: str, attrs, phase=None):
# TODO: needs an additional passed in arguments to specify the needs for each plugin
# such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83
trt.IPluginV3.__init__(self)
# Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime
trt.IPluginV3OneCore.__init__(self)
# Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder.
trt.IPluginV3OneBuild.__init__(self)
# Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable
- trt.IPluginV3OneRuntime.__init__(self)
-
+ trt.IPluginV3OneRuntime.__init__(self)
+
# <ANY NON TENSOR INPUTS SHOULD BE AN ATTRIBUTE OF THE PLUGIN>
- # setattr(<name of input>, <default value for that type>)
+ # setattr(<name of input>, <default value for that type>)
# self.pads = []
# self.X_shape: List[int] = []
-
- self.num_outputs = 1 # Defined by schema
+
+ self.num_outputs = 1 # Defined by schema
self.plugin_namespace = ""
self.plugin_name = plugin_name
- self.plugin_version = "1"
+ self.plugin_version = "1"
# Set the timing cache ID to prevent unnecessary timing of second plugin instance
self.timing_cache_id = ""
self.attrs = attrs
-
+
self.tactic = None
-
-
- # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR>
+
+ # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR>
# ex.
# TODO: need to parse the field collection here
# if fc is not None:
# assert fc[0].name == "pads"
# self.pads = fc[0].data
@@ -77,14 +76,12 @@
self.phase = phase
def get_capability_interface(self, type):
return self
- def get_output_data_types(
- self, input_types: List[trt.DataType]
- ) -> trt.DataType:
- # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE
+ def get_output_data_types(self, input_types: List[trt.DataType]) -> trt.DataType:
+ # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE
# with torch.fake_tensor():
# <GENERATE FAKE INPUTS OF TYPE INPUT_TYPES>
# fake_outputs = torch.ops.<custom_ns>.<custom_op>(*fake_inputs)
# return fake_outputs[index]
@@ -96,20 +93,20 @@
self,
inputs: List[trt.DimsExprs],
shape_inputs,
exprBuilder: trt.IExprBuilder,
) -> trt.DimsExprs:
-
+
print(inputs)
- # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR
- # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE
- # SHAPE MAP.
+ # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR
+ # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE
+ # SHAPE MAP.
output_dims = trt.DimsExprs(inputs[0])
return [output_dims]
-
+
def get_fields_to_serialize(self):
# should be passed in as another argument
field_names = []
for key, value in self.attrs.items():
@@ -149,11 +146,11 @@
self.X_shape = np.zeros((len(X_dims),))
for i in range(len(X_dims)):
self.X_shape[i] = X_dims[i]
def supports_format_combination(self, pos, in_out, num_inputs):
- return
+ return
assert num_inputs == 1
assert pos < len(in_out)
desc = in_out[pos].desc
if desc.format != trt.TensorFormat.LINEAR:
@@ -166,11 +163,10 @@
# output should have the same type as the input
if pos == 1:
return in_out[0].desc.type == desc.type
assert False
-
def enqueue(
self,
input_desc: List[trt.PluginTensorDesc],
output_desc: List[trt.PluginTensorDesc],
@@ -180,40 +176,56 @@
stream: int,
) -> None:
# input and output memory handling
input_mems = [None] * (len(inputs))
- for i in range(len(inputs)):
- input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self)
+ for i in range(len(inputs)):
+ input_mems[i] = cp.cuda.UnownedMemory(
+ inputs[i],
+ np.prod(input_desc[i].dims)
+ * cp.dtype(trt.nptype(input_desc[i].type)).itemsize,
+ self,
+ )
output_mems = [None] * (len(outputs))
for i in range(len(outputs)):
- output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self)
-
+ output_mems[i] = cp.cuda.UnownedMemory(
+ outputs[i],
+ np.prod(output_desc[i].dims)
+ * cp.dtype(trt.nptype(output_desc[i].type)).itemsize,
+ self,
+ )
input_data = [None] * ((len(inputs)))
for i in range(len(inputs)):
- input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0))
+ input_data[i] = cp.ndarray(
+ tuple(input_desc[i].dims),
+ dtype=input_desc[i].type,
+ memptr=cp.cuda.MemoryPointer(input_mems[i], 0),
+ )
output_data = [None] * ((len(outputs)))
for i in range(len(outputs)):
- output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0))
-
- #TODO: This is just for a simple case for elementwise operations
+ output_data[i] = cp.ndarray(
+ (np.prod(output_desc[i].dims)),
+ dtype=output_desc[i].type,
+ memptr=cp.cuda.MemoryPointer(output_mems[i], 0),
+ )
+
+ # TODO: This is just for a simple case for elementwise operations
# using Torch implementation for now
- input_torch_0 = torch.as_tensor(input_data[0], device='cuda')
- input_torch_1 = torch.as_tensor(input_data[1], device='cuda')
+ input_torch_0 = torch.as_tensor(input_data[0], device="cuda")
+ input_torch_1 = torch.as_tensor(input_data[1], device="cuda")
output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1)
cp.copyto(output_data, output)
-
def attach_to_context(self, context):
return self.clone()
-
+
def get_valid_tactics(self):
return [int(Tactic.TORCH), int(Tactic.TRITON)]
def set_tactic(self, tactic):
self.tactic = Tactic(tactic)
@@ -226,17 +238,17 @@
cloned_plugin.__dict__.update(self.__dict__)
return cloned_plugin
class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc]
- def __init__(self, plugin_name : str, plugin_namespace : str, attrs):
- trt.IPluginCreatorV3One.__init__(self)
+ def __init__(self, plugin_name: str, plugin_namespace: str, attrs):
+ trt.IPluginCreatorV3One.__init__(self)
self.name = plugin_name
self.plugin_namespace = plugin_namespace
self.plugin_version = "1"
-
+
field_names = []
for name, (builtin, type_) in attrs.items():
if builtin:
if type_ is str:
field_names.append(
@@ -259,15 +271,12 @@
)
)
self.field_names = trt.PluginFieldCollection(field_names)
- def create_plugin(
- self, name: str, field_collection, phase=None
- ) -> CustomPlugin:
-
-
+ def create_plugin(self, name: str, field_collection, phase=None) -> CustomPlugin:
+
attrs = {}
# for f in fc:
# if f.name not in desc.input_attrs:
# raise AssertionError(
# f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
@@ -275,10 +284,9 @@
# if _is_numpy_array(desc.input_attrs[f.name]):
# attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name]))
# else:
# attrs[f.name] = desc.input_attrs[f.name](f.data)
-
+
custom_plugin = CustomPlugin(name, attrs)
-
+
return custom_plugin
-
ejmoney1
approved these changes
Dec 3, 2024
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
cla signed
component: api [Python]
Issues re: Python API
component: build system
Issues re: Build system
component: conversion
Issues re: Conversion stage
component: dynamo
Issues relating to the `torch.compile` or `torch._dynamo.export` paths
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR implements the automatic plugin feature.
Please delete options that are not relevant and/or add your own.
Checklist: