[XNNPACKQuantizer 2/N][TorchFX] SharedQuantizationSpec support #3385

Draft · wants to merge 2 commits into base: develop
3 changes: 2 additions & 1 deletion nncf/common/quantization/quantizer_setup.py
@@ -247,12 +247,13 @@ def __init__(self) -> None:
self._next_unified_scale_gid = 0
self._next_shared_inputs_gid = 0

-    def add_independent_quantization_point(self, qp: QuantizationPointBase) -> None:
+    def add_independent_quantization_point(self, qp: QuantizationPointBase) -> int:
if self.quantization_points.keys():
new_id = max(self.quantization_points.keys()) + 1
else:
new_id = 0
self.quantization_points[new_id] = qp
+        return new_id

def register_unified_scale_group(self, qp_group: List[QuantizationPointId]) -> int:
for qp_id in qp_group:
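The return value added above is what the TorchFX adapter in this PR relies on: the ids of all quantization points created for one group are collected and handed to `register_unified_scale_group`. A minimal sketch of the call pattern, with hypothetical points `qp_a` and `qp_b`:

```python
# qp_a and qp_b are hypothetical QuantizationPointBase instances, e.g. for the
# two inputs of a cat node that must share scale and zero-point.
setup = SingleConfigQuantizerSetup()
qp_ids = [
    setup.add_independent_quantization_point(qp_a),
    setup.add_independent_quantization_point(qp_b),
]
if len(qp_ids) > 1:
    # Unify the points so NNCF keeps a single set of quantizer parameters.
    setup.register_unified_scale_group(qp_ids)
```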
@@ -11,18 +11,18 @@


from collections import defaultdict
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union

import torch
import torch.fx
+from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
+from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

import nncf
from nncf.common.graph.graph import NNCFGraph
from nncf.common.logging import nncf_logger
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
from nncf.common.quantization.quantizer_setup import QuantizationPointBase
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
@@ -73,6 +73,15 @@ def _get_quantization_points(
annotated_model: torch.fx.GraphModule,
qconfig: QuantizerConfig,
) -> List[QuantizationPointBase]:
"""
Creates quantization points based on the nodes and edges.

:param from_node: The originating node in the computation graph.
:param to_nodes: The list of destination nodes of the from_node.
:param annotated_model: The torch.fx.GraphModule instance.
:param qconfig: The torch.ao quantization configuration.
:return: A list of NNCF quantization points.
"""
to_n = to_nodes[0]
if from_node.op == "get_attr":
_, metatype = GraphConverter.get_node_type_and_metatype(to_n, annotated_model)
@@ -95,78 +104,102 @@
return qps

@staticmethod
-    def _get_node_args(node: torch.fx.Node):
+    def _get_node_args(node: torch.fx.Node) -> Tuple[Any, ...]:
"""
Correctly retrieves arguments of the given node.

:param node: The given node.
:return: The arguments of the given node.
"""
if node.target == torch.ops.aten.cat.default:
return node.args[0]
return node.args
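A note on the `aten.cat` special case above: `cat` receives its inputs as a single list argument, so `node.args` holds that list rather than the individual producer nodes. A minimal illustrative sketch, not part of the PR (traced with `symbolic_trace`, where the target is `torch.cat` rather than the aten op):

```python
import torch
import torch.fx


def f(x, y):
    return torch.cat([x, y], dim=0)


gm = torch.fx.symbolic_trace(f)
cat_node = next(n for n in gm.graph.nodes if n.target is torch.cat)
# The producer nodes live inside args[0]; dim arrives via args[1] or kwargs.
print(cat_node.args)     # ([x, y],)
print(cat_node.args[0])  # [x, y] -- what _get_node_args returns for cat
```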

@staticmethod
-    def get_quantizer_config_from_annotated_model(annotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
-        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated_model)
-
-        q_map = defaultdict(list)
-        for edge, qspec in edge_or_node_to_qspec.items():
-            if not isinstance(edge, tuple):
-                continue
-            from_n, to_n = edge
-            q_map[from_n].append(to_n)
+    def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
+        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated)
+        # A Node key means all of the node's output edges should be quantized;
+        # an Edge key means only that single edge should be quantized.
+        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)

+        group_id_vs_edges = defaultdict(set)
+        group_id_vs_qspec = {}
+        for edge_or_node, group_id in edge_or_node_to_group_id.items():
+            target_edges = [edge_or_node]
+            if isinstance(edge_or_node, torch.fx.Node):
+                target_edges = []
+                for user in edge_or_node.users:
+                    target_edges.append((edge_or_node, user))
+            group_id_vs_edges[group_id].update(target_edges)
+            # All qspecs should be aligned after the _get_edge_or_node_to_group_id call.
+            group_id_vs_qspec[group_id] = _unwrap_shared_qspec_safe(
+                edge_or_node_to_qspec[edge_or_node], edge_or_node_to_qspec
+            )

q_setup = SingleConfigQuantizerSetup()
-        for from_n, to_nodes in q_map.items():
-            to_n = to_nodes[0]
-            qspec = edge_or_node_to_qspec[(from_n, to_n)]
+        for group_id, edges in group_id_vs_edges.items():
+            qspec = group_id_vs_qspec[group_id]
if qspec is None:
continue
-            if isinstance(qspec, QuantizationSpec):
-                if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
-                    per_channel = True
-                elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
-                    per_channel = False
-                else:
-                    msg = f"Unknown qscheme: {qspec.qscheme}"
-                    raise nncf.InternalError(msg)
-                signed = qspec.dtype is torch.int8
-                mode = (
-                    QuantizationMode.SYMMETRIC
-                    if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
-                    else QuantizationMode.ASYMMETRIC
-                )
-                qconfig = QuantizerConfig(mode=mode, signedness_to_force=signed, per_channel=per_channel)
-
-                qps = TorchAOQuantizerAdapter._get_quantization_points(from_n, to_nodes, annotated_model, qconfig)
-                for qp in qps:
-                    q_setup.add_independent_quantization_point(qp)
-
-            elif isinstance(qspec, SharedQuantizationSpec):
-                # TODO(dlyakhov): Support SharedQuantizationSpec
-                nncf_logger.warning(
-                    f"SharedQuantizationSpec is not supported yet; edges {from_n} -> {to_nodes} won't be quantized."
-                )
-            else:
+            if not isinstance(qspec, QuantizationSpec):
msg = f"Unknown torch.ao quantization spec: {qspec}"
raise nncf.InternalError(msg)

+            if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
+                per_channel = True
+            elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                per_channel = False
+            else:
+                msg = f"Unknown qscheme: {qspec.qscheme}"
+                raise nncf.InternalError(msg)
+
+            signed = qspec.dtype is torch.int8
+            mode = (
+                QuantizationMode.SYMMETRIC
+                if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
+                else QuantizationMode.ASYMMETRIC
+            )
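+            # An odd quant_min (e.g. -127 instead of -128) marks the narrow symmetric range [-127, 127].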
+            narrow_range = qspec.quant_min % 2 != 0
+            qconfig = QuantizerConfig(
+                mode=mode, signedness_to_force=signed, per_channel=per_channel, narrow_range=narrow_range
+            )

+            joined_edges = defaultdict(list)
+            for edge in edges:
+                joined_edges[edge[0]].append(edge[1])
+
+            qps = []
+            for from_node, to_nodes in joined_edges.items():
+                qps.extend(TorchAOQuantizerAdapter._get_quantization_points(from_node, to_nodes, annotated, qconfig))
+            qp_ids = []
+            for qp in qps:
+                qp_ids.append(q_setup.add_independent_quantization_point(qp))
+            if len(qp_ids) > 1:
+                q_setup.register_unified_scale_group(qp_ids)

return q_setup
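For intuition, the grouping above is what turns a torch.ao `SharedQuantizationSpec` annotation into one NNCF unified scale group. A hedged sketch of such an annotation, assuming `left`, `right`, and `cat_node` are hypothetical `torch.fx.Node`s of a concat:

```python
import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=HistogramObserver,
)
# cat_node concatenates left and right; the second input reuses the spec of the
# (left, cat_node) edge, so both edges receive one group id and end up in a
# single unified scale group after the conversion above.
cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
    input_qspec_map={
        left: act_qspec,
        right: SharedQuantizationSpec((left, cat_node)),
    }
)
```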


-def _get_edge_or_node_to_qspec(
-    model: torch.fx.GraphModule,
-) -> Dict[EdgeOrNode, QuantizationSpecBase]:
+def _unwrap_shared_qspec_safe(qspec: QuantizationSpec, edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpec]) -> QuantizationSpec:
"""
-    Get a map from EdgeOrNode to quantization spec based on annotations on the nodes.
+    Iteratively unwraps a given SharedQuantizationSpec to retrieve the concrete QuantizationSpec it refers to.
+    Cyclic dependencies are detected, and a maximum depth limit prevents an unbounded loop.

-    :param model: torch.fx.GraphModule instance.
-    :return: A map from EdgeOrNode to quantization spec based on annotations on the nodes.
+    :param qspec: The quantization spec to unwrap.
+    :param edge_or_node_to_qspec: A mapping from EdgeOrNode instances to their quantization specs.
+    :return: The resolved QuantizationSpec.
"""
-    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
-    for n in model.graph.nodes:
-        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
-            qa = n.meta["quantization_annotation"]
-            for input_to_n, qspec in qa.input_qspec_map.items():
-                input_edge = (input_to_n, n)
-                edge_or_node_to_qspec[input_edge] = qspec
-            if qa.output_qspec is not None:
-                output_node = n
-                qspec = qa.output_qspec
-                edge_or_node_to_qspec[output_node] = qspec
-    return edge_or_node_to_qspec
+    MAX_DEPTH = 1000
+    i = 0
+    visited = []
+    while i < MAX_DEPTH and isinstance(qspec, SharedQuantizationSpec):
+        if qspec.edge_or_node in visited:
+            msg = f"A cyclic dependency between quantization specs is detected: {visited + [qspec.edge_or_node]}"
+            raise RuntimeError(msg)
+        visited.append(qspec.edge_or_node)
+        qspec = edge_or_node_to_qspec[qspec.edge_or_node]
+        i += 1
+    if i == MAX_DEPTH:
+        msg = f"Shared qspecs reference each other deeper than the limit: {MAX_DEPTH}"
+        raise RuntimeError(msg)
+    return qspec
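A quick illustration of the helper's contract, using strings as hypothetical stand-ins for real edges and nodes:

```python
import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

concrete = QuantizationSpec(dtype=torch.int8, observer_or_fake_quant_ctr=HistogramObserver)
chain = {"a": SharedQuantizationSpec("b"), "b": concrete}
# Shared("a") -> Shared("b") -> concrete resolves to the concrete spec.
assert _unwrap_shared_qspec_safe(SharedQuantizationSpec("a"), chain) is concrete

cycle = {"a": SharedQuantizationSpec("b"), "b": SharedQuantizationSpec("a")}
# Raises RuntimeError: a cyclic dependency between the specs is detected.
_unwrap_shared_qspec_safe(SharedQuantizationSpec("a"), cycle)
```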
32 changes: 13 additions & 19 deletions nncf/experimental/torch/fx/transformations.py
@@ -383,27 +383,21 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua

# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
-    if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
-        user_dq_nodes = []
-        with graph.inserting_after(quantized_node):
-            for user in target_node.users:

+    with graph.inserting_after(quantized_node):
+        dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})
+        dq_node.meta["val"] = copy(meta_val)
+    if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
+        for user in list(target_node.users):
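+            # list() takes a snapshot: replace_input_with below mutates target_node.users while we iterate.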
if user is quantized_node:
continue
-                dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})
-                dq_node.meta["val"] = copy(meta_val)
-                user_dq_nodes.append((user, dq_node))
-
-        for user, dq_node in user_dq_nodes:
-            user.replace_input_with(target_node, dq_node)
-    elif target_point.target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
-        with graph.inserting_after(quantized_node):
-            dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})
-            dq_node.meta["val"] = copy(meta_val)
-
-        target_node.replace_input_with(input_node, dq_node)
-    else:
-        msg = f"Unexpected target type: {target_point.target_type}"
-        raise nncf.InternalError(msg)
+            user.replace_input_with(target_node, dq_node)

+    elif target_point.target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
+        target_node.replace_input_with(input_node, dq_node)
+    else:
+        msg = f"Unexpected target type: {target_point.target_type}"
+        raise nncf.InternalError(msg)
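Net effect of the refactor above: a single shared dequantize node is created up front and every consumer is rewired to it, instead of one dequantize node per user as before. A self-contained sketch of the post-hook rewiring on a toy graph (`torch.clone` stands in for the real dequantize op; names are illustrative):

```python
import torch
import torch.fx


def f(x):
    y = torch.relu(x)
    return y + 1, y * 2


gm = torch.fx.symbolic_trace(f)
graph = gm.graph
relu = next(n for n in graph.nodes if n.target is torch.relu)

with graph.inserting_after(relu):
    dq = graph.call_function(torch.clone, (relu,))
for user in list(relu.users):  # snapshot, since rewiring mutates relu.users
    if user is dq:
        continue
    user.replace_input_with(relu, dq)
graph.lint()
gm.recompile()
```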


def _insert_call_module(