
Commit cddf2ca

cccclai authored and facebook-github-bot committed

[Reland][pytorch] Patch the _is_conv_node function

Summary: Add the conv padding ops in PyTorch; the corresponding PR in torchao is pytorch/ao#2257.

Test Plan:

```
buck test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_conv_padding_bn_relu (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
```

Differential Revision: D75494468

1 parent e79790e commit cddf2ca
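For context, a minimal sketch (not part of the commit) of why the patch is needed: building a convolution with string padding such as `"same"` dispatches to the aten `.padding` overload rather than `.default`, which the old `_is_conv_node` check did not match. The use of `torch.export.export_for_training` below mirrors the PT2E flow but is an assumption, not taken from the diff.

```python
# Hypothetical repro sketch; assumes torch.export.export_for_training is
# available (recent PyTorch releases). Not part of the commit itself.
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # String padding ("same"/"valid") dispatches to the .padding overload.
        self.conv = torch.nn.Conv2d(3, 3, 3, padding="same")

    def forward(self, x):
        return self.conv(x)


ep = torch.export.export_for_training(M(), (torch.randn(1, 3, 5, 5),))
print([n.target for n in ep.graph.nodes if n.op == "call_function"])
# Expected to include torch.ops.aten.conv2d.padding rather than
# torch.ops.aten.conv2d.default, which is why _is_conv_node needed patching.
```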

File tree

3 files changed: +122 -5 lines

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 112 additions & 0 deletions

```diff
@@ -2333,6 +2333,118 @@ def validate(self, model: torch.fx.GraphModule) -> None:
             node_list,
         )
 
+    def test_conv_padding_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer,
+                )
+                weight_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                bias_qspec = QuantizationSpec(
+                    dtype=torch.float32,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
+                )
+
+                for n in model.graph.nodes:
+                    if (
+                        n.op != "call_function"
+                        or n.target != torch.ops.aten.relu.default
+                    ):
+                        continue
+                    relu_node = n
+                    n = n.args[0]
+
+                    # Check for any of the conv operations
+                    conv_ops = [
+                        torch.ops.aten.conv1d.padding,
+                        torch.ops.aten.conv2d.padding,
+                        torch.ops.aten.conv3d.padding,
+                    ]
+                    if n.op != "call_function" or n.target not in conv_ops:
+                        continue
+
+                    conv_node = n
+                    input_act = conv_node.args[0]
+                    weight = conv_node.args[1]
+                    bias = conv_node.args[2]
+                    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        input_qspec_map={
+                            input_act: act_qspec,
+                            weight: weight_qspec,
+                            bias: bias_qspec,
+                        },
+                        _annotated=True,
+                    )
+                    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                        output_qspec=act_qspec,
+                        _annotated=True,
+                    )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        # Test cases for Conv1d, Conv2d, Conv3d
+        test_cases = [
+            {
+                "dim": 1,
+                "example_input": torch.randn(1, 3, 5),
+                "conv_op": torch.ops.aten.conv1d.padding,
+            },
+            {
+                "dim": 2,
+                "example_input": torch.randn(1, 3, 5, 5),
+                "conv_op": torch.ops.aten.conv2d.padding,
+            },
+            {
+                "dim": 3,
+                "example_input": torch.randn(1, 3, 5, 5, 5),
+                "conv_op": torch.ops.aten.conv3d.padding,
+            },
+        ]
+
+        for test_case in test_cases:
+            with self.subTest(dim=test_case["dim"]):
+                model = TestHelperModules.ConvWithBNRelu(
+                    relu=True,
+                    dim=test_case["dim"],
+                    bn=True,
+                    bias=True,
+                    padding="same",  # This will trigger the .padding variants
+                ).eval()
+
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    test_case["conv_op"],
+                    torch.ops.aten.relu.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                ]
+
+                self._test_quantizer(
+                    model,
+                    test_case["example_input"],
+                    BackendAQuantizer(),
+                    node_occurrence,
+                    node_list,
+                )
+
     def test_multi_users_without_output_observer(self):
         """
         Test the case in which a node is used by multiple users,
```

torch/ao/quantization/pt2e/utils.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -167,7 +167,11 @@ def _is_conv_node(n: Node):
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv1d.padding,
         torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv2d.padding,
+        torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv3d.padding,
     ]
 
 
```
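As a quick sanity check of the patched matcher, one could build a bare FX node whose target is a string-padding overload and run `_is_conv_node` on it directly. This is a hypothetical usage sketch, not part of the commit (`_is_conv_node` is a private utility):

```python
# Hypothetical check: construct an FX node targeting the string-padding
# overload and confirm the matcher recognizes it.
import torch
from torch.fx import Graph
from torch.ao.quantization.pt2e.utils import _is_conv_node

g = Graph()
x = g.placeholder("x")
w = g.placeholder("w")
# Args mirror aten::conv2d.padding(input, weight, bias, stride, padding,
# dilation, groups); the node is never executed, only pattern-matched.
conv = g.call_function(
    torch.ops.aten.conv2d.padding, (x, w, None, [1, 1], "same", [1, 1], 1)
)
print(_is_conv_node(conv))  # True with this patch; False before it
```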
torch/testing/_internal/common_quantization.py

Lines changed: 6 additions & 5 deletions

```diff
@@ -3183,13 +3183,14 @@ def forward(self, x):
             x = self.adaptive_avg_pool2d(x)
             return x
 
+
     class ConvWithBNRelu(torch.nn.Module):
-        def __init__(self, relu, dim=2, bn=True, bias=True):
+        def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
             super().__init__()
-            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
-            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
-            self.conv = convs[dim](3, 3, 3, bias=bias)
-
+            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
+            self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)
+
             if bn:
                 self.bn = bns[dim](3)
             else:
```
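A short usage sketch of the extended helper (hypothetical, not from the commit): with `padding="same"` the conv preserves spatial dimensions, which is what lets the new test push 5-wide inputs through a 3-wide kernel and still match the expected graph.

```python
# Hypothetical usage of the extended test helper; not part of the commit.
import torch
from torch.testing._internal.common_quantization import TestHelperModules

m = TestHelperModules.ConvWithBNRelu(
    relu=True, dim=3, bn=True, bias=True, padding="same"
).eval()
out = m(torch.randn(1, 3, 5, 5, 5))
print(out.shape)  # torch.Size([1, 3, 5, 5, 5]); "same" padding keeps 5x5x5
```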
