@@ -137,3 +137,33 @@ TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
137
137
ASSERT_TRUE (
138
138
torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
139
139
}
140
+
141
+ TEST (Converters, ATenBatchNormHalfConvertsCorrectly) {
142
+ const auto graph = R"IR(
143
+ graph(%input : Tensor, %running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0), %running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0)):
144
+ %5 : bool = prim::Constant[value=0]()
145
+ %4 : float = prim::Constant[value=0.01]()
146
+ %3 : float = prim::Constant[value=0.001]()
147
+ %2 : bool = prim::Constant[value=1]()
148
+ %8 : Tensor = aten::batch_norm(%input, %running_var, %running_mean, %running_mean, %running_var, %5, %4, %3, %2)
149
+ return (%8))IR" ;
150
+
151
+ auto g = std::make_shared<torch::jit::Graph>();
152
+ torch::jit::parseIR (graph, &*g);
153
+
154
+ auto in = at::randn ({2 , 32 , 5 , 5 }, {at::kCUDA }).to (at::kHalf );
155
+ auto mean = at::ones ({32 }, {at::kCUDA }).to (at::kHalf );
156
+ auto var = at::zeros ({32 }, {at::kCUDA }).to (at::kHalf );
157
+
158
+ auto trt_in = at::clone (in);
159
+ auto trt_mean = at::clone (mean);
160
+ auto trt_var = at::clone (var);
161
+
162
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {mean, var});
163
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
164
+
165
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_mean, trt_var});
166
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in}, {nvinfer1::DataType::kHALF });
167
+
168
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-2 ));
169
+ }
0 commit comments