Skip to content

Commit d61815c

Browse files
dulinriley authored and pytorchmergebot committed
[torch][ao] Use returned model from Quantizer.transform_for_annotation in prepare_pt2e (pytorch#132893)
Summary: The Quantizer subclass can return a new model from `transform_for_annotation`, and this is common if it uses any ExportPass subclass which does not mutate in-place. Use the returned model instead of assuming it's the same. Differential Revision: D60869676 Pull Request resolved: pytorch#132893 Approved by: https://github.com/jerryzh168
1 parent 1371c42 commit d61815c

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -1500,9 +1500,14 @@ class TestQuantizer(Quantizer):
15001500
def transform_for_annotation(
15011501
self, model: torch.fx.GraphModule
15021502
) -> torch.fx.GraphModule:
1503-
for n in model.graph.nodes:
1503+
# Make a copy of the graph to ensure that we are using the
1504+
# return value of this function.
1505+
graph = torch.fx.Graph()
1506+
graph.graph_copy(model.graph, {})
1507+
for n in graph.nodes:
15041508
if n.target == torch.ops.aten.add.Tensor:
15051509
n.target = torch.ops.aten.mul.Tensor
1510+
model = torch.fx.GraphModule(model, graph)
15061511
return model
15071512

15081513
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

torch/ao/quantization/quantize_pt2e.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ def calibrate(model, data_loader):
9696
# to be quantized before fusion
9797
# TODO: (maybe) rewrite this with subgraph_rewriter
9898
_fuse_conv_bn_(model)
99-
quantizer.transform_for_annotation(model)
99+
model = quantizer.transform_for_annotation(model)
100100
quantizer.annotate(model)
101101
quantizer.validate(model)
102102
model = prepare(model, node_name_to_scope, is_qat=False)
@@ -165,7 +165,7 @@ def train_loop(model, train_data):
165165
torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
166166
original_graph_meta = model.meta
167167
node_name_to_scope = _get_node_name_to_scope(model)
168-
quantizer.transform_for_annotation(model)
168+
model = quantizer.transform_for_annotation(model)
169169
quantizer.annotate(model)
170170
quantizer.validate(model)
171171
# Perform fusion after annotate to avoid quantizing ops in the new

0 commit comments

Comments (0)