@@ -52,47 +52,50 @@ def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
     return X.like()
 
 
-# @trtp.aot_impl("my::add_one")
-# def add_plugin_aot_impl(
-#     X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
-# ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
-#     type_str = "fp32" if X.dtype == trt.float32 else "fp16"
-
-#     block_size = 256
-#     src = triton.compiler.ASTSource(
-#         fn=add_one_kernel,
-#         signature={
-#             "x_ptr": f"*{type_str}",
-#             "n_elements": "i32",
-#             "y_ptr": f"*{type_str}",
-#             "BLOCK_SIZE": "constexpr",
-#         },
-#         constants={
-#             "BLOCK_SIZE": block_size,
-#         },
-#     )
-
-#     compiled_kernel = triton.compile(src)
-
-#     N = X.shape_expr.numel()
-#     launch_params = trtp.KernelLaunchParams()
-
-#     # grid dims
-#     launch_params.grid_x = trtp.cdiv(N, block_size)
-#     # block dims
-#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
-#     # shared memory
-#     launch_params.shared_mem = compiled_kernel.metadata.shared
-
-#     extra_args = trtp.SymIntExprs(1)
-#     extra_args[0] = trtp.SymInt32(N)
-
-#     return (
-#         compiled_kernel.metadata.name,
-#         compiled_kernel.asm["ptx"],
-#         launch_params,
-#         extra_args,
-#     )
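+# AOT implementation for the plugin: compile the Triton kernel to PTX at engine
+# build time and return the kernel name, PTX, and launch parameters, so the
+# engine embeds the kernel instead of calling back into Python at runtime.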
+@trtp.aot_impl("my::add_one")
+def add_plugin_aot_impl(
+    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
+) -> Tuple[
+    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+]:
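+    # Triton pointer-type string for the kernel signature, chosen from the input dtype.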
+    type_str = "fp32" if X.dtype == trt.float32 else "fp16"
+
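+    # Describe the kernel source for Triton's compiler; BLOCK_SIZE is baked in
+    # as a compile-time constant, the remaining entries are runtime arguments.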
+    block_size = 256
+    src = triton.compiler.ASTSource(
+        fn=add_one_kernel,
+        signature={
+            "x_ptr": f"*{type_str}",
+            "n_elements": "i32",
+            "y_ptr": f"*{type_str}",
+            "BLOCK_SIZE": "constexpr",
+        },
+        constants={
+            "BLOCK_SIZE": block_size,
+        },
+    )
+
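+    # Compile ahead of time; the metadata records num_warps and shared-memory usage.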
+    compiled_kernel = triton.compile(src)
+
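+    # Element count as a symbolic expression over the (possibly dynamic) input shape.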
+    N = X.shape_expr.numel()
+    launch_params = trtp.KernelLaunchParams()
+
+    # grid dims: one block per block_size elements, covering all N elements
+    launch_params.grid_x = trtp.cdiv(N, block_size)
+    # block dims: threads per block = warps chosen by Triton * 32 threads/warp
+    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+    # shared memory (bytes) required by the compiled kernel
+    launch_params.shared_mem = compiled_kernel.metadata.shared
+
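+    # Non-tensor kernel argument (n_elements), passed as a symbolic int.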
+    extra_args = trtp.SymIntExprs(1)
+    extra_args[0] = trtp.SymInt32(N)
+
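+    # Hand TensorRT the kernel name, its PTX, the launch config, and extra args.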
+    return (
+        compiled_kernel.metadata.name,
+        compiled_kernel.asm["ptx"],
+        launch_params,
+        extra_args,
+    )
+
 
 torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
     "my::add_one",
@@ -113,7 +116,6 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--aot", action="store_true", help="Try to use AOT compilation", default=False
@@ -123,7 +125,6 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     my_model = MyModel().to("cuda")
     m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
 
-    # This works!
     assert my_model(X=m)[0][0] == 3.0
 
     with torch_tensorrt.logging.debug():
@@ -141,4 +142,3 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     assert torch.allclose(res, my_model(m)), "Results do not match!"
 
     print("Inference successful!")
-    print(res)