@@ -2333,6 +2333,118 @@ def validate(self, model: torch.fx.GraphModule) -> None:
2333
2333
node_list ,
2334
2334
)
2335
2335
2336
+ def test_conv_padding_bn_relu (self ):
2337
+ class BackendAQuantizer (Quantizer ):
2338
+ def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
2339
+ act_qspec = QuantizationSpec (
2340
+ dtype = torch .uint8 ,
2341
+ quant_min = 0 ,
2342
+ quant_max = 255 ,
2343
+ qscheme = torch .per_tensor_affine ,
2344
+ is_dynamic = False ,
2345
+ observer_or_fake_quant_ctr = observer .default_observer ,
2346
+ )
2347
+ weight_qspec = QuantizationSpec (
2348
+ dtype = torch .int8 ,
2349
+ quant_min = - 128 ,
2350
+ quant_max = 127 ,
2351
+ qscheme = torch .per_tensor_affine ,
2352
+ is_dynamic = False ,
2353
+ observer_or_fake_quant_ctr = observer .default_weight_observer ,
2354
+ )
2355
+ bias_qspec = QuantizationSpec (
2356
+ dtype = torch .float32 ,
2357
+ is_dynamic = False ,
2358
+ observer_or_fake_quant_ctr = observer .PlaceholderObserver ,
2359
+ )
2360
+
2361
+ for n in model .graph .nodes :
2362
+ if (
2363
+ n .op != "call_function"
2364
+ or n .target != torch .ops .aten .relu .default
2365
+ ):
2366
+ continue
2367
+ relu_node = n
2368
+ n = n .args [0 ]
2369
+
2370
+ # Check for any of the conv operations
2371
+ conv_ops = [
2372
+ torch .ops .aten .conv1d .padding ,
2373
+ torch .ops .aten .conv2d .padding ,
2374
+ torch .ops .aten .conv3d .padding ,
2375
+ ]
2376
+ if n .op != "call_function" or n .target not in conv_ops :
2377
+ continue
2378
+
2379
+ conv_node = n
2380
+ input_act = conv_node .args [0 ]
2381
+ weight = conv_node .args [1 ]
2382
+ bias = conv_node .args [2 ]
2383
+ conv_node .meta ["quantization_annotation" ] = QuantizationAnnotation (
2384
+ input_qspec_map = {
2385
+ input_act : act_qspec ,
2386
+ weight : weight_qspec ,
2387
+ bias : bias_qspec ,
2388
+ },
2389
+ _annotated = True ,
2390
+ )
2391
+ relu_node .meta ["quantization_annotation" ] = QuantizationAnnotation (
2392
+ output_qspec = act_qspec ,
2393
+ _annotated = True ,
2394
+ )
2395
+
2396
+ def validate (self , model : torch .fx .GraphModule ) -> None :
2397
+ pass
2398
+
2399
+ # Test cases for Conv1d, Conv2d, Conv3d
2400
+ test_cases = [
2401
+ {
2402
+ "dim" : 1 ,
2403
+ "example_input" : (torch .randn (1 , 3 , 5 ),),
2404
+ "conv_op" : torch .ops .aten .conv1d .padding ,
2405
+ },
2406
+ {
2407
+ "dim" : 2 ,
2408
+ "example_input" : (torch .randn (1 , 3 , 5 , 5 ),),
2409
+ "conv_op" : torch .ops .aten .conv2d .padding ,
2410
+ },
2411
+ {
2412
+ "dim" : 3 ,
2413
+ "example_input" : (torch .randn (1 , 3 , 5 , 5 , 5 ),),
2414
+ "conv_op" : torch .ops .aten .conv3d .padding ,
2415
+ },
2416
+ ]
2417
+
2418
+ for test_case in test_cases :
2419
+ with self .subTest (dim = test_case ["dim" ]):
2420
+ model = TestHelperModules .ConvWithBNRelu (
2421
+ relu = True ,
2422
+ dim = test_case ["dim" ],
2423
+ bn = True ,
2424
+ bias = True ,
2425
+ padding = "same" , # This will trigger the .padding variants
2426
+ ).eval ()
2427
+
2428
+ node_occurrence = {
2429
+ torch .ops .quantized_decomposed .quantize_per_tensor .default : 2 ,
2430
+ torch .ops .quantized_decomposed .dequantize_per_tensor .default : 3 ,
2431
+ }
2432
+ node_list = [
2433
+ torch .ops .quantized_decomposed .dequantize_per_tensor .default ,
2434
+ torch .ops .quantized_decomposed .dequantize_per_tensor .default ,
2435
+ test_case ["conv_op" ],
2436
+ torch .ops .aten .relu .default ,
2437
+ torch .ops .quantized_decomposed .quantize_per_tensor .default ,
2438
+ ]
2439
+
2440
+ self ._test_quantizer (
2441
+ model ,
2442
+ test_case ["example_input" ],
2443
+ BackendAQuantizer (),
2444
+ node_occurrence ,
2445
+ node_list ,
2446
+ )
2447
+
2336
2448
def test_multi_users_without_output_observer (self ):
2337
2449
"""
2338
2450
Test the case in which a node is used by multiple users,
0 commit comments