11
11
12
12
13
13
from collections import defaultdict
14
- from typing import Dict , List , Tuple , Union
14
+ from typing import Any , Dict , List , Tuple , Union
15
15
16
16
import torch
17
17
import torch .fx
18
+ from torch .ao .quantization .pt2e .prepare import _get_edge_or_node_to_group_id
19
+ from torch .ao .quantization .pt2e .prepare import _get_edge_or_node_to_qspec
18
20
from torch .ao .quantization .quantizer import Quantizer as TorchAOQuantizer
19
21
from torch .ao .quantization .quantizer .quantizer import QuantizationSpec
20
- from torch .ao .quantization .quantizer .quantizer import QuantizationSpecBase
21
22
from torch .ao .quantization .quantizer .quantizer import SharedQuantizationSpec
22
23
23
24
import nncf
24
25
from nncf .common .graph .graph import NNCFGraph
25
- from nncf .common .logging import nncf_logger
26
26
from nncf .common .quantization .quantizer_setup import ActivationQuantizationInsertionPoint
27
27
from nncf .common .quantization .quantizer_setup import QuantizationPointBase
28
28
from nncf .common .quantization .quantizer_setup import SingleConfigQuantizationPoint
@@ -73,6 +73,15 @@ def _get_quantization_points(
73
73
annotated_model : torch .fx .GraphModule ,
74
74
qconfig : QuantizerConfig ,
75
75
) -> List [QuantizationPointBase ]:
76
+ """
77
+ Creates quantization points based on the nodes and edges.
78
+
79
+ :param from_node: The originating node in the computation graph.
80
+ :param to_nodes: The list of destination nodes of the from_node.
81
+ :param annotated_model: The torch.fx.GraphModule instance.
82
+ :param qconfig: The torch.ao quantization configuration.
83
+ :return: A list of NNCF quantization points.
84
+ """
76
85
to_n = to_nodes [0 ]
77
86
if from_node .op == "get_attr" :
78
87
_ , metatype = GraphConverter .get_node_type_and_metatype (to_n , annotated_model )
@@ -95,78 +104,102 @@ def _get_quantization_points(
95
104
return qps
96
105
97
106
@staticmethod
98
- def _get_node_args (node : torch .fx .Node ):
107
+ def _get_node_args (node : torch .fx .Node ) -> Tuple [Any , ...]:
108
+ """
109
+ Correctly retrieves arguments of the given node.
110
+
111
+ :param node: The given node.
112
+ :return: The arguments of the given node.
113
+ """
99
114
if node .target == torch .ops .aten .cat .default :
100
115
return node .args [0 ]
101
116
return node .args
102
117
103
118
@staticmethod
104
- def get_quantizer_config_from_annotated_model (annotated_model : torch .fx .GraphModule ) -> SingleConfigQuantizerSetup :
105
- edge_or_node_to_qspec = _get_edge_or_node_to_qspec (annotated_model )
106
-
107
- q_map = defaultdict (list )
108
- for edge , qspec in edge_or_node_to_qspec .items ():
109
- if not isinstance (edge , tuple ):
110
- continue
111
- from_n , to_n = edge
112
- q_map [from_n ].append (to_n )
119
+ def get_quantizer_config_from_annotated_model (annotated : torch .fx .GraphModule ) -> SingleConfigQuantizerSetup :
120
+ edge_or_node_to_qspec = _get_edge_or_node_to_qspec (annotated )
121
+ # Node means all output edges should be quantized.
122
+ # Edge means only one edge should be quantized.
123
+ edge_or_node_to_group_id = _get_edge_or_node_to_group_id (edge_or_node_to_qspec )
124
+
125
+ group_id_vs_edges = defaultdict (set )
126
+ group_id_vs_qspec = {}
127
+ for edge_or_node , group_id in edge_or_node_to_group_id .items ():
128
+ target_edges = [edge_or_node ]
129
+ if isinstance (edge_or_node , torch .fx .Node ):
130
+ target_edges = []
131
+ for user in edge_or_node .users :
132
+ target_edges .append ((edge_or_node , user ))
133
+ group_id_vs_edges [group_id ].update (target_edges )
134
+ # All qspecs should be aligned after the _get_edge_or_node_to_group_id call
135
+ group_id_vs_qspec [group_id ] = _unwrap_shared_qspec_safe (
136
+ edge_or_node_to_qspec [edge_or_node ], edge_or_node_to_qspec
137
+ )
113
138
114
139
q_setup = SingleConfigQuantizerSetup ()
115
- for from_n , to_nodes in q_map .items ():
116
- to_n = to_nodes [0 ]
117
- qspec = edge_or_node_to_qspec [(from_n , to_n )]
140
+ for group_id , edges in group_id_vs_edges .items ():
141
+ qspec = group_id_vs_qspec [group_id ]
118
142
if qspec is None :
119
143
continue
120
- if isinstance (qspec , QuantizationSpec ):
121
- if qspec .qscheme in [torch .per_channel_affine , torch .per_channel_symmetric ]:
122
- per_channel = True
123
- elif qspec .qscheme in [torch .per_tensor_affine , torch .per_tensor_symmetric ]:
124
- per_channel = False
125
- else :
126
- msg = f"Unknown qscheme: { qspec .qscheme } "
127
- raise nncf .InternalError (msg )
128
- signed = qspec .dtype is torch .int8
129
- mode = (
130
- QuantizationMode .SYMMETRIC
131
- if qspec .qscheme in [torch .per_channel_symmetric , torch .per_tensor_symmetric ]
132
- else QuantizationMode .ASYMMETRIC
133
- )
134
- qconfig = QuantizerConfig (mode = mode , signedness_to_force = signed , per_channel = per_channel )
135
-
136
- qps = TorchAOQuantizerAdapter ._get_quantization_points (from_n , to_nodes , annotated_model , qconfig )
137
- for qp in qps :
138
- q_setup .add_independent_quantization_point (qp )
139
-
140
- elif isinstance (qspec , SharedQuantizationSpec ):
141
- # TODO(dlyakhov): Support SharedQuantizationSpec
142
- nncf_logger .warning (
143
- f"SharedQuantizationSpec is not supported yet; edges { from_n } -> { to_nodes } won't be quantized."
144
- )
145
- else :
144
+ if not isinstance (qspec , QuantizationSpec ):
146
145
msg = f"Unknown torch.ao quantization spec: { qspec } "
147
146
raise nncf .InternalError (msg )
148
147
148
+ if qspec .qscheme in [torch .per_channel_affine , torch .per_channel_symmetric ]:
149
+ per_channel = True
150
+ elif qspec .qscheme in [torch .per_tensor_affine , torch .per_tensor_symmetric ]:
151
+ per_channel = False
152
+ else :
153
+ msg = f"Unknown qscheme: { qspec .qscheme } "
154
+ raise nncf .InternalError (msg )
155
+
156
+ signed = qspec .dtype is torch .int8
157
+ mode = (
158
+ QuantizationMode .SYMMETRIC
159
+ if qspec .qscheme in [torch .per_channel_symmetric , torch .per_tensor_symmetric ]
160
+ else QuantizationMode .ASYMMETRIC
161
+ )
162
+ narrow_range = qspec .quant_min % 2 != 0
163
+ qconfig = QuantizerConfig (
164
+ mode = mode , signedness_to_force = signed , per_channel = per_channel , narrow_range = narrow_range
165
+ )
166
+
167
+ joined_edges = defaultdict (list )
168
+ for edge in edges :
169
+ joined_edges [edge [0 ]].append (edge [1 ])
170
+
171
+ qps = []
172
+ for from_node , to_nodes in joined_edges .items ():
173
+ qps .extend (TorchAOQuantizerAdapter ._get_quantization_points (from_node , to_nodes , annotated , qconfig ))
174
+ qp_ids = []
175
+ for qp in qps :
176
+ qp_ids .append (q_setup .add_independent_quantization_point (qp ))
177
+ if len (qp_ids ) > 1 :
178
+ q_setup .register_unified_scale_group (qp_ids )
179
+
149
180
return q_setup
150
181
151
182
152
- def _get_edge_or_node_to_qspec (
153
- model : torch .fx .GraphModule ,
154
- ) -> Dict [EdgeOrNode , QuantizationSpecBase ]:
183
+ def _unwrap_shared_qspec_safe (qspec : QuantizationSpec , edge_or_node_to_qspec : Dict [EdgeOrNode , QuantizationSpec ]):
155
184
"""
156
- Get a map from EdgeOrNode to quantization spec based on annotations on the nodes.
185
+ Iteratively unwraps a given SharedQuantizationSpec to retrieve its actual QuantizationSpec.
186
+ It detects cyclic dependencies and enforces a maximum depth limit to prevent infinite recursion.
157
187
158
- :param model: torch.fx.GraphModule instance.
159
- :return: A map from EdgeOrNode to quantization spec based on annotations on the nodes.
188
+ :param qspec: The quantization specification to unwrap.
189
+ :param edge_or_node_to_qspec: A dictionary mapping EdgeOrNode instances to their respective QuantizationSpec.
190
+ :return: The resolved QuantizationSpec.
160
191
"""
161
- edge_or_node_to_qspec : Dict [EdgeOrNode , QuantizationSpecBase ] = {}
162
- for n in model .graph .nodes :
163
- if hasattr (n , "meta" ) and "quantization_annotation" in n .meta :
164
- qa = n .meta ["quantization_annotation" ]
165
- for input_to_n , qspec in qa .input_qspec_map .items ():
166
- input_edge = (input_to_n , n )
167
- edge_or_node_to_qspec [input_edge ] = qspec
168
- if qa .output_qspec is not None :
169
- output_node = n
170
- qspec = qa .output_qspec
171
- edge_or_node_to_qspec [output_node ] = qspec
172
- return edge_or_node_to_qspec
192
+ MAX_DEPTH = 1000
193
+ i = 0
194
+ visited = []
195
+ while i < MAX_DEPTH and isinstance (qspec , SharedQuantizationSpec ):
196
+ if qspec .edge_or_node in visited :
197
+ msg = f"A cycled dependency of the quantization spec is detected { visited + [qspec .edge_or_node ]} "
198
+ raise RuntimeError (msg )
199
+ visited .append (qspec .edge_or_node )
200
+ qspec = edge_or_node_to_qspec [qspec .edge_or_node ]
201
+ i += 1
202
+ if i == MAX_DEPTH :
203
+ msg = f"Shared qspecs referenced to each other more than the limit: { MAX_DEPTH } "
204
+ raise RuntimeError (msg )
205
+ return qspec
0 commit comments