Skip to content

Patch the _is_conv_node function #2257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,6 +2571,124 @@ def forward(self, x):
node_list,
)

def test_conv_padding_bn_relu(self):
class BackendAQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
)
weight_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_weight_observer,
)
bias_qspec = QuantizationSpec(
dtype=torch.float32,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.PlaceholderObserver,
)

for n in model.graph.nodes:
if (
n.op != "call_function"
or n.target != torch.ops.aten.relu.default
):
continue
relu_node = n
n = n.args[0]

# Check for any of the conv operations
conv_ops = [
torch.ops.aten.conv1d.padding,
torch.ops.aten.conv2d.padding,
torch.ops.aten.conv3d.padding,
]
if n.op != "call_function" or n.target not in conv_ops:
continue

conv_node = n
input_act = conv_node.args[0]
weight = conv_node.args[1]
bias = conv_node.args[2]
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act: act_qspec,
weight: weight_qspec,
bias: bias_qspec,
},
_annotated=True,
)
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
output_qspec=act_qspec,
_annotated=True,
)

def validate(self, model: torch.fx.GraphModule) -> None:
pass

# Test cases for Conv1d, Conv2d, Conv3d
test_cases = [
{
"conv_type": torch.nn.Conv1d,
"bn_type": torch.nn.BatchNorm1d,
"example_input": (torch.randn(1, 3, 5),),
"conv_op": torch.ops.aten.conv1d.padding,
},
{
"conv_type": torch.nn.Conv2d,
"bn_type": torch.nn.BatchNorm2d,
"example_input": (torch.randn(1, 3, 5, 5),),
"conv_op": torch.ops.aten.conv2d.padding,
},
{
"conv_type": torch.nn.Conv3d,
"bn_type": torch.nn.BatchNorm3d,
"example_input": (torch.randn(1, 3, 5, 5, 5),),
"conv_op": torch.ops.aten.conv3d.padding,
},
]

for test_case in test_cases:
with self.subTest(conv_type=test_case["conv_type"].__name__):

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = test_case["conv_type"](3, 3, 3, padding="same")
self.bn = test_case["bn_type"](3)

def forward(self, x):
return torch.nn.functional.relu(self.bn(self.conv(x)))

node_occurrence = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
}
node_list = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
test_case["conv_op"],
torch.ops.aten.relu.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
]

model = M().eval()
self._test_quantizer(
model,
test_case["example_input"],
BackendAQuantizer(),
node_occurrence,
node_list,
)

def test_multi_users_without_output_observer(self):
"""
Test the case in which a node is used by multiple users,
Expand Down
3 changes: 3 additions & 0 deletions torchao/quantization/pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,11 @@ def _is_conv_node(n: Node):
"""
return n.op == "call_function" and n.target in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv1d.padding,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv2d.padding,
torch.ops.aten.conv3d.default,
torch.ops.aten.conv3d.padding,
]


Expand Down
Loading