Commit c3cd651 (parent: dc36709)

feat: enable AOT tensorrt plugin example

File tree

2 files changed: +160, -1 lines
  examples/dynamo/aot_plugin.py
  py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py


examples/dynamo/aot_plugin.py (new file, 156 additions, 0 deletions)
import argparse
from typing import Tuple, Union

import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

trt_logger = trt.Logger(trt.Logger.VERBOSE)


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = x + 1
    tl.store(y_ptr + offsets, output, mask=mask)


@torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
def add_one(X: torch.Tensor) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda

    # Create output tensor
    Y = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 256

    # Grid of programs
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

    # Launch the kernel
    add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)

    return Y


@torch.library.register_fake("my::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
    return X


# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
#     "my::add_one"
# )


@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return X.like()


@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    block_size = 256
    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature={
            "x_ptr": f"*{type_str}",
            "n_elements": "i32",
            "y_ptr": f"*{type_str}",
            "BLOCK_SIZE": "constexpr",
        },
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims
    launch_params.grid_x = trtp.cdiv(N, block_size)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    aot=True,
)


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        res = torch.ops.my.add_one.default(X)

        return res


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--aot", action="store_true", help="Try to use AOT compilation", default=False
    )
    args = parser.parse_args()

    my_model = MyModel().to("cuda")
    m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

    # This works!
    assert my_model(X=m)[0][0] == 3.0

    with torch_tensorrt.logging.debug():
        trt_inputs = [m]
        model_trt = torch_tensorrt.compile(
            my_model,
            inputs=trt_inputs,
            debug=True,
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        for i in range(10):
            res = model_trt(m)
            assert torch.allclose(res, my_model(m)), "Results do not match!"

        print("Inference successful!")
        print(res)
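
For comparison, the commented-out generate_plugin call near the top of the file is the alternative path: it derives the TensorRT plugin registration automatically from the custom op, instead of the hand-written @trtp.register / @trtp.aot_impl pair used here. A minimal sketch of that variant (not part of this commit, assuming the same my::add_one op and fake impl are already registered):

# Sketch of the auto-generated, non-AOT variant, based on the
# commented-out call in the file above.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin("my::add_one")
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    # aot defaults to False (see the converter diff below), so the
    # plugin's just-in-time path is used instead of @trtp.aot_impl.
)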

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py (4 additions, 1 deletion)

@@ -31,6 +31,7 @@ def _generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    aot: bool = False,
 ) -> DynamoConverterImplSignature:
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
@@ -80,7 +81,7 @@ def custom_kernel_converter(
         if isinstance(v, torch.fx.immutable_collections.immutable_list):
             kwargs[k] = np.array(v)

-    layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
+    layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=aot)
     assert layer, f"{namespace}::{name} plugin layer was not able to be created"
     _LOGGER.debug(
         f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -107,6 +108,7 @@ def generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    aot: bool = False,
 ) -> DynamoConverterImplSignature:
     plugin_ns, plugin_name = plugin_id.split("::")
     return _generate_plugin_converter(
@@ -116,4 +118,5 @@ def generate_plugin_converter(
         priority=priority,
         supports_dynamic_shapes=supports_dynamic_shapes,
         requires_output_allocator=requires_output_allocator,
+        aot=aot,
     )
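
Taken together, the new aot flag threads from generate_plugin_converter through _generate_plugin_converter into ctx.net.add_plugin, letting a converter opt in to the ahead-of-time kernel registered with @trtp.aot_impl. A minimal usage sketch, mirroring the call in the example file above:

from torch_tensorrt.dynamo.conversion.plugins import generate_plugin_converter

# aot=True makes the generated converter call
# ctx.net.add_plugin(..., aot=True), selecting the kernel supplied by
# the op's @trtp.aot_impl instead of the default just-in-time path.
generate_plugin_converter(
    "my::add_one",  # op registered in examples/dynamo/aot_plugin.py
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    aot=True,
)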
