Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Initial implementation for automatic plugin #3301

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions examples/dynamo/automatic_plugin/custom_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import triton
import triton.language as tl

@triton.jit
def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
    """Write X + Y into Z, one BLOCK_SIZE-wide slice per program instance.

    X, Y and Z are indexed with the same flat offsets, and the loads and
    stores below carry no bounds mask, so every offset touched by the launch
    grid must be valid for all three buffers.
    """
    # Program instance index along launch axis 0 selects the slice.
    first = tl.program_id(0) * BLOCK_SIZE
    # Flat element offsets covered by this program instance.
    idx = first + tl.arange(0, BLOCK_SIZE)
    # Fetch both operands for this slice.
    lhs = tl.load(X + idx)
    rhs = tl.load(Y + idx)
    # Element-wise addition, written straight to the output buffer.
    tl.store(Z + idx, lhs + rhs)


import torch
from torch.library import custom_op


@custom_op("torchtrt_ex::elementwise_add", mutates_args=())  # type: ignore[misc]
def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Return X + Y computed by the Triton ``elementwise_add_kernel``.

    Args:
        X: CUDA tensor.
        Y: CUDA tensor with the same shape as ``X``.

    Returns:
        A new tensor allocated with ``torch.empty_like(X)`` holding the sum.

    Raises:
        AssertionError: if either tensor is not on a CUDA device or the
            shapes differ.

    Note:
        The kernel addresses X, Y and the output with identical flat offsets,
        so the inputs are assumed to share the same dense memory layout.
    """
    # Ensure the tensors are on the GPU and shape-compatible.
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."

    # Create output tensor
    Z = torch.empty_like(X)

    num_elements = X.numel()
    if num_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return Z

    # The kernel performs unmasked loads/stores, so the grid must tile the
    # data exactly. `n & -n` is the largest power of two dividing n; capping
    # it at 1024 keeps the original block size whenever the element count is
    # a multiple of 1024 and otherwise shrinks the block so that no tail
    # elements are skipped (previously `numel // 1024` silently dropped them).
    MAX_BLOCK_SIZE = 1024
    BLOCK_SIZE = min(MAX_BLOCK_SIZE, num_elements & -num_elements)

    # One program per BLOCK_SIZE-wide slice; the division is exact.
    grid = (num_elements // BLOCK_SIZE,)

    # Launch the kernel
    elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)

    return Z


# Using the module in PyTorch
# X = torch.randn(1024, device='cuda', requires_grad=True)
# Y = torch.randn(1024, device='cuda', requires_grad=True)
# X = torch.full((128, 128), 2, device='cuda',)
# Y = torch.full((128, 128), 2, device='cuda',)
# # elementwise_mul_op = ElementwiseMulModule()
# Z = torch.ops.torchtrt_ex.elementwise_add(X, Y)
# print(Z)
# print(X + Y)
# print(X)
# print(Y)
# print(Z)
# print(X+Y)
# Z.sum().backward()


from torch import nn


class MyModel(nn.Module):  # type: ignore[misc]
    """Toy module computing ``x + x * y`` with the custom add op."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Multiply the inputs, then add ``x`` through ``torchtrt_ex::elementwise_add``."""
        product = torch.mul(x, y)
        return torch.ops.torchtrt_ex.elementwise_add(x, product)


# Example model instance and two constant-filled 64x64 CUDA inputs
# (all 2s and all 3s) used for the compilation run further down.
my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda")
n = torch.full((64, 64), 3, device="cuda")
# print(torch.ops.torchtrt_ex.elementwise_add(m, n))
# print(my_model.forward(m, n))


@torch.library.register_fake("torchtrt_ex::elementwise_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Fake (meta) implementation of ``torchtrt_ex::elementwise_add``.

    Describes the output metadata without computing anything. The real op
    allocates its result with ``torch.empty_like`` on the first input, so the
    fake does the same. It must return a fresh tensor rather than ``x``
    itself, because the op is declared with ``mutates_args=()`` and its
    output is not an alias of an input.
    """
    return torch.empty_like(x)

import torch_tensorrt as torchtrt


# Compile the example model with Torch-TensorRT and run it once on the
# example inputs, printing the result.
# NOTE(review): indentation was lost in this view; all three statements are
# assumed to belong to the `with` body -- confirm against the original file.
with torchtrt.logging.info():
    # min_block_size=1 is passed so that even a single-op block is accepted
    # for conversion -- presumably to force the custom op through TensorRT;
    # TODO confirm.
    model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
    res = model_trt(m, n)
    print(res)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import aten_ops_converters, ops_evaluators, prims_ops_converters
from . import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .plugin_generator import PluginCreator
284 changes: 284 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
import tensorrt as trt
import cupy as cp
import torch
import numpy as np

import logging


from enum import IntEnum
from typing import List


logger = logging.getLogger("CustomPlugin")

# NumPy dtype -> TensorRT plugin field type, used for array-valued attributes.
# Booleans are stored as INT8.
_numpy_to_plugin_field_type = {
    np.dtype("int32"): trt.PluginFieldType.INT32,
    np.dtype("int16"): trt.PluginFieldType.INT16,
    np.dtype("int8"): trt.PluginFieldType.INT8,
    np.dtype("bool"): trt.PluginFieldType.INT8,
    np.dtype("int64"): trt.PluginFieldType.INT64,
    np.dtype("float32"): trt.PluginFieldType.FLOAT32,
    np.dtype("float64"): trt.PluginFieldType.FLOAT64,
    np.dtype("float16"): trt.PluginFieldType.FLOAT16,
}

# Python builtin scalar type -> TensorRT plugin field type.
# `str` is handled separately by the callers, so it has no entry here.
_built_in_to_plugin_field_type = {
    int: trt.PluginFieldType.INT64,
    float: trt.PluginFieldType.FLOAT64,
    bool: trt.PluginFieldType.INT8,
}

class Tactic(IntEnum):
    """Integer identifiers for the implementation tactics the plugin reports."""

    # Implementation backed by a Torch op call.
    TORCH = 1
    # Implementation backed by a Triton kernel.
    TRITON = 2

class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime):  # type: ignore[misc]
    """TensorRT V3 plugin that executes a Torch custom op inside an engine.

    A single object implements the core, build and runtime capability
    interfaces. The current implementation is an experiment hard-wired to a
    two-input, one-output element-wise op
    (``torch.ops.torchtrt_ex.elementwise_add``).
    """

    def __init__(
        self, plugin_name: str, attrs, phase=None
    ):
        """Initialise the plugin.

        Args:
            plugin_name: Name reported through ``plugin_name``.
            attrs: Mapping of attribute name to value (NumPy array, ``str``,
                ``bytes`` or a Python scalar) that is serialized by
                ``get_fields_to_serialize``.
            phase: Optional TensorRT phase; stored on ``self.phase`` only
                when it is not ``None``.
        """
        # TODO: needs an additional passed in arguments to specify the needs for each plugin
        # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83
        trt.IPluginV3.__init__(self)
        # Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin's lifetime
        trt.IPluginV3OneCore.__init__(self)
        # Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder.
        trt.IPluginV3OneBuild.__init__(self)
        # Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable
        trt.IPluginV3OneRuntime.__init__(self)

        # TODO: any non-tensor inputs of the op should become attributes of
        # the plugin, populated from a parsed field collection.

        self.num_outputs = 1  # Defined by schema
        self.plugin_namespace = ""
        self.plugin_name = plugin_name
        self.plugin_version = "1"

        # Set the timing cache ID to prevent unnecessary timing of second plugin instance
        self.timing_cache_id = ""

        self.attrs = attrs

        self.tactic = None

        if phase is not None:
            self.phase = phase

    def get_capability_interface(self, type):
        """Return this object for every requested capability type."""
        return self

    def get_output_data_types(
        self, input_types: List[trt.DataType]
    ) -> List[trt.DataType]:
        """Return the output data types: one output typed like the first input.

        TODO: derive this from the op's fake-tensor implementation instead of
        assuming the output matches input 0.
        """
        # The example case here is simple for experiment
        return [input_types[0]]

    def get_output_shapes(
        self,
        inputs: List[trt.DimsExprs],
        shape_inputs,
        exprBuilder: trt.IExprBuilder,
    ) -> List[trt.DimsExprs]:
        """Return the output shapes: one output shaped like the first input.

        TODO: find a way to go from the fake tensor implementation to a
        DimsExprs; PyTorch shape propagation can supply SymInts that encode
        the shape map.
        """
        output_dims = trt.DimsExprs(inputs[0])

        return [output_dims]

    def get_fields_to_serialize(self):
        """Pack ``self.attrs`` into a ``trt.PluginFieldCollection``.

        NumPy arrays keep their dtype, strings are encoded to bytes as CHAR
        fields, raw bytes are stored as UNKNOWN, and any other value is
        wrapped in a one-element array typed by its Python builtin type.

        Raises:
            KeyError: if an array dtype or scalar type has no entry in the
                module's field-type tables.
        """
        fields = []

        for key, value in self.attrs.items():
            if isinstance(value, np.ndarray):
                fields.append(
                    trt.PluginField(
                        key,
                        value,
                        _numpy_to_plugin_field_type[np.dtype(value.dtype)],
                    )
                )
            elif isinstance(value, str):
                fields.append(
                    trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR)
                )
            elif isinstance(value, bytes):
                fields.append(
                    trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN)
                )
            else:
                fields.append(
                    trt.PluginField(
                        key,
                        np.array([value]),
                        _built_in_to_plugin_field_type[type(value)],
                    )
                )

        return trt.PluginFieldCollection(fields)

    def configure_plugin(self, inp, out):
        """No configuration is required for this plugin."""
        pass

    def on_shape_change(self, inp, out):
        """No per-shape state is kept, so shape changes need no handling."""
        return

    def supports_format_combination(self, pos, in_out, num_inputs):
        """Report whether the tensor at ``pos`` has a usable format and type.

        ``in_out`` lists the input descriptors followed by the output
        descriptors. A tensor is accepted when it uses the LINEAR format and

        * at position 0: its data type can be mapped to a NumPy dtype by
          ``trt.nptype`` (``enqueue`` relies on that mapping to wrap the
          device buffers);
        * at any later position: its data type equals that of the tensor at
          position 0 (the op is element-wise and ``get_output_data_types``
          returns the type of input 0).

        Returns:
            bool: always an explicit ``True``/``False`` (the previous body
            returned ``None`` unconditionally).
        """
        assert pos < len(in_out)

        desc = in_out[pos].desc
        if desc.format != trt.TensorFormat.LINEAR:
            return False

        if pos == 0:
            try:
                trt.nptype(desc.type)
            except TypeError:
                # No NumPy equivalent, so enqueue could not wrap the buffer.
                return False
            return True

        # Remaining inputs and the output must match the first input's type.
        return in_out[0].desc.type == desc.type

    def _wrap_device_buffer(self, ptr, desc, flatten):
        """Wrap a raw device pointer as a CuPy array without taking ownership.

        Args:
            ptr: Device address of the buffer.
            desc: ``trt.PluginTensorDesc`` giving the buffer's dims and type.
            flatten: When true the array is 1-D with all elements; otherwise
                it has the shape given by ``desc.dims``.
        """
        np_dtype = trt.nptype(desc.type)
        dims = tuple(desc.dims)
        num_elements = int(np.prod(dims))
        mem = cp.cuda.UnownedMemory(
            ptr, num_elements * cp.dtype(np_dtype).itemsize, self
        )
        shape = (num_elements,) if flatten else dims
        return cp.ndarray(
            shape, dtype=np_dtype, memptr=cp.cuda.MemoryPointer(mem, 0)
        )

    def enqueue(
        self,
        input_desc: List[trt.PluginTensorDesc],
        output_desc: List[trt.PluginTensorDesc],
        inputs,
        outputs,
        workspace: int,
        stream: int,
    ) -> None:
        """Run the op on the engine-provided device buffers.

        Inputs are wrapped as shaped CuPy arrays and outputs as flat CuPy
        arrays; the first two inputs are handed to
        ``torch.ops.torchtrt_ex.elementwise_add`` and the result is copied
        into the first output buffer.

        NOTE(review): ``workspace``, ``stream`` and ``self.tactic`` are not
        used here; the Torch implementation runs regardless of the selected
        tactic.
        """
        # Zero-copy views over the buffers. The TensorRT data type has to be
        # converted with trt.nptype before CuPy can use it as a dtype.
        input_data = [
            self._wrap_device_buffer(ptr, desc, flatten=False)
            for ptr, desc in zip(inputs, input_desc)
        ]
        output_data = [
            self._wrap_device_buffer(ptr, desc, flatten=True)
            for ptr, desc in zip(outputs, output_desc)
        ]

        # TODO: This is just for a simple case for elementwise operations
        # using Torch implementation for now
        input_torch_0 = torch.as_tensor(input_data[0], device="cuda")
        input_torch_1 = torch.as_tensor(input_data[1], device="cuda")

        output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1)

        # Copy into the single output buffer: cp.copyto needs CuPy arrays on
        # both sides, so view the Torch result through CuPy and flatten it to
        # match the flat output view.
        cp.copyto(output_data[0], cp.reshape(cp.asarray(output), (-1,)))

    def attach_to_context(self, context):
        """Return a clone of this plugin for the given context."""
        return self.clone()

    def get_valid_tactics(self):
        """Return the integer ids of all ``Tactic`` members."""
        return [int(Tactic.TORCH), int(Tactic.TRITON)]

    def set_tactic(self, tactic):
        """Record the selected tactic as a ``Tactic`` member."""
        self.tactic = Tactic(tactic)

    def clone(self):
        """Return a new plugin carrying a shallow copy of this instance's attributes."""
        cloned_plugin = CustomPlugin(self.plugin_name, self.attrs)
        cloned_plugin.__dict__.update(self.__dict__)
        return cloned_plugin


class PluginCreator(trt.IPluginCreatorV3One):  # type: ignore[misc]
    """Creator that builds ``CustomPlugin`` instances for TensorRT."""

    def __init__(self, plugin_name: str, plugin_namespace: str, attrs):
        """Initialise the creator and declare the expected plugin fields.

        Args:
            plugin_name: Stored on ``self.name``.
            plugin_namespace: Stored on ``self.plugin_namespace``.
            attrs: Mapping of attribute name to a ``(builtin, type_)`` pair.
                When ``builtin`` is true, ``type_`` is a Python builtin type
                (``str``, ``bytes``, or a key of the builtin field-type
                table); otherwise ``type_`` is anything ``np.dtype`` accepts.

        Raises:
            KeyError: if a type has no entry in the field-type tables.
        """
        trt.IPluginCreatorV3One.__init__(self)

        self.name = plugin_name
        self.plugin_namespace = plugin_namespace
        self.plugin_version = "1"

        # Each declared field carries an empty placeholder payload; only the
        # name and field type are meaningful at this point.
        field_names = []
        for name, (builtin, type_) in attrs.items():
            if builtin:
                if type_ is str:
                    field_names.append(
                        trt.PluginField(name, b"", trt.PluginFieldType.CHAR)
                    )
                elif type_ is bytes:
                    field_names.append(
                        trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN)
                    )
                else:
                    field_names.append(
                        trt.PluginField(
                            name, np.array([]), _built_in_to_plugin_field_type[type_]
                        )
                    )
            else:
                field_names.append(
                    trt.PluginField(
                        name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)]
                    )
                )

        self.field_names = trt.PluginFieldCollection(field_names)

    def create_plugin(
        self, name: str, field_collection, phase=None
    ) -> CustomPlugin:
        """Create a ``CustomPlugin`` named ``name``.

        Args:
            name: Passed through as the plugin's name.
            field_collection: Currently unused; see the TODO below.
            phase: Forwarded to the plugin so it can record the phase it was
                created for (it was previously accepted but dropped).

        Returns:
            The newly constructed ``CustomPlugin``.
        """
        # TODO: populate attrs from field_collection, validating each field
        # name against the attributes declared for this creator and
        # converting the field data to the declared type.
        attrs = {}

        return CustomPlugin(name, attrs, phase)

Loading
Loading