import argparse
from typing import Tuple, Union

import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl

trt_logger = trt.Logger(trt.Logger.VERBOSE)


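# Triton kernel: adds 1 to every element of a flat buffer. Each program
# instance handles one BLOCK_SIZE-wide chunk; the mask guards the final
# partial block.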
@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = x + 1
    tl.store(y_ptr + offsets, output, mask=mask)


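# Register the kernel as a PyTorch custom op (my::add_one) so it can be
# captured in a traced graph. This eager implementation launches the Triton
# kernel directly.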
@torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
def add_one(X: torch.Tensor) -> torch.Tensor:
    # Ensure the tensor is on the GPU
    assert X.is_cuda

    # Create output tensor
    Y = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 256

    # Grid of programs
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

    # Launch the kernel
    add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)

    return Y


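# Fake (meta) implementation used during tracing: it only needs to produce
# a tensor with the correct shape and dtype, without touching real data.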
@torch.library.register_fake("my::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
    # Return a fresh tensor rather than the input itself: the op is
    # registered as non-aliasing (mutates_args=()), so the fake impl must
    # not alias its argument.
    return torch.empty_like(X)


# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
#     "my::add_one"
# )

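# Describe the plugin to TensorRT: the output tensor has the same shape and
# dtype as the input.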
@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return X.like()

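# Ahead-of-time (AOT) implementation: compile the Triton kernel to PTX at
# engine build time and return the kernel name, PTX, launch parameters, and
# extra scalar arguments, so no Python callback is needed at runtime.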
@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    block_size = 256
    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature={
            "x_ptr": f"*{type_str}",
            "n_elements": "i32",
            "y_ptr": f"*{type_str}",
            "BLOCK_SIZE": "constexpr",
        },
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims
    launch_params.grid_x = trtp.cdiv(N, block_size)
    # block dims
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory
    launch_params.shared_mem = compiled_kernel.metadata.shared

    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )

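# Generate a converter so the torch_tensorrt Dynamo frontend lowers
# my::add_one to the TensorRT plugin registered above; aot=True requests
# the ahead-of-time path.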
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    aot=True,
)


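# A tiny module whose forward pass is just the custom op, used to exercise
# the converter end to end.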
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        res = torch.ops.my.add_one.default(X)

        return res


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--aot", action="store_true", help="Try to use AOT compilation", default=False
    )
    args = parser.parse_args()

    my_model = MyModel().to("cuda")
    m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

    # Sanity check: the eager custom op works before compilation
    assert my_model(X=m)[0][0] == 3.0

    with torch_tensorrt.logging.debug():
        trt_inputs = [m]
        model_trt = torch_tensorrt.compile(
            my_model,
            inputs=trt_inputs,
            debug=True,
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        for i in range(10):
            res = model_trt(m)
            assert torch.allclose(res, my_model(m)), "Results do not match!"

    print("Inference successful!")
    print(res)