From ed3b1f4922ef352520720d166d82f074769636f9 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 15 Jan 2026 18:34:01 +0100 Subject: [PATCH] WIP --- src/nncf/common/graph/graph.py | 93 +++++++-------- src/nncf/common/insertion_point_graph.py | 14 +-- .../sparsify_activations/torch_backend.py | 8 +- src/nncf/openvino/graph/nncf_graph_builder.py | 22 ++-- .../algorithms/smooth_quant/algorithm.py | 10 +- .../weight_compression/torch_backend.py | 22 ++-- .../weight_compression/torch_fx_backend.py | 22 +++- src/nncf/quantization/passes.py | 27 ++--- .../nncf_graph/nncf_graph_builder.py | 18 +-- src/nncf/torch/graph/graph.py | 29 +---- src/nncf/torch/model_graph_manager.py | 14 +-- tests/common/graph/test_nncf_graph.py | 108 ++++++++++++++---- tests/cross_fw/shared/nx_graph.py | 2 +- tests/cross_fw/test_templates/models.py | 34 +++--- .../native/test_nncf_graph_builder.py | 1 - tests/torch/fx/test_quantizer.py | 1 - tests/torch/fx/test_weights_compression.py | 2 +- tests/torch/test_graph_analysis.py | 1 - tests/torch/utils.py | 3 - 19 files changed, 230 insertions(+), 201 deletions(-) diff --git a/src/nncf/common/graph/graph.py b/src/nncf/common/graph/graph.py index d8b1a86a8ae..730ed77cc96 100644 --- a/src/nncf/common/graph/graph.py +++ b/src/nncf/common/graph/graph.py @@ -143,7 +143,6 @@ def __init__( output_port_id: int, tensor_shape: list[int], dtype: Dtype, - parallel_input_port_ids: list[int], ) -> None: """ :param from_node: An NNCFNode that sources the directed edge. @@ -159,7 +158,6 @@ def __init__( self.output_port_id = output_port_id self.tensor_shape: tuple[int, ...] = tuple(tensor_shape) self.dtype = dtype - self.parallel_input_port_ids = parallel_input_port_ids def __str__(self) -> str: return f"{self.from_node}:{self.output_port_id} -> {self.tensor_shape} -> {self.to_node}:{self.input_port_id}" @@ -173,7 +171,6 @@ def __hash__(self) -> int: self.output_port_id, tuple(self.tensor_shape), self.dtype, - tuple(self.parallel_input_port_ids), ) ) @@ -201,10 +198,9 @@ class NNCFGraph: INPUT_PORT_ID_EDGE_ATTR = "input_port_id" OUTPUT_PORT_ID_EDGE_ATTR = "output_port_id" DTYPE_EDGE_ATTR = "dtype" - PARALLEL_INPUT_PORT_IDS_ATTR = "parallel_input_ports" def __init__(self) -> None: - self._nx_graph = nx.DiGraph() + self._nx_graph = nx.MultiDiGraph() self._node_id_to_key_dict: dict[int, str] = {} self._nodes: dict[str, NNCFNode] = {} self._input_nncf_nodes: dict[int, NNCFNode] = {} @@ -353,7 +349,7 @@ def get_input_edges(self, node: NNCFNode) -> list[NNCFGraphEdge]: input_nodes = self.get_previous_nodes(node) edges = [] for from_node in input_nodes: - edges.extend(self._get_edges(from_node, node)) + edges.extend(self.get_edges(from_node, node)) return sorted(edges, key=lambda x: x.input_port_id) def get_input_edge_by_port_id(self, node: NNCFNode, port_id: int) -> NNCFGraphEdge: @@ -385,7 +381,7 @@ def get_output_edges(self, node: NNCFNode) -> list[NNCFGraphEdge]: output_nodes = self.get_next_nodes(node) edges = [] for to_node in output_nodes: - edges.extend(self._get_edges(node, to_node)) + edges.extend(self.get_edges(node, to_node)) return sorted(edges, key=lambda x: x.output_port_id) def get_output_edges_by_port_id(self, node: NNCFNode, port_id: int) -> list[NNCFGraphEdge]: @@ -400,26 +396,6 @@ def get_output_edges_by_port_id(self, node: NNCFNode, port_id: int) -> list[NNCF """ return [e for e in self.get_output_edges(node) if e.output_port_id == port_id] - def _get_edges(self, from_node: NNCFNode, to_node: NNCFNode) -> list[NNCFGraphEdge]: - edges = [] - edge = self.get_edge(from_node, to_node) - parallel_input_port_ids = edge.parallel_input_port_ids - edge.parallel_input_port_ids = [] - edges.append(edge) - for input_port_id in parallel_input_port_ids: - edges.append( - NNCFGraphEdge( - from_node=edge.from_node, - to_node=edge.to_node, - input_port_id=input_port_id, - output_port_id=edge.output_port_id, - tensor_shape=list(edge.tensor_shape), - dtype=edge.dtype, - parallel_input_port_ids=[], - ) - ) - return edges - def traverse_graph( self, curr_node: NNCFNode, @@ -552,7 +528,6 @@ def add_edge_between_nncf_nodes( input_port_id: int, output_port_id: int, dtype: Dtype, - parallel_input_port_ids: Optional[list[int]] = None, ) -> None: """ Adds a directed edge between two `NNCFNode`s that are already present in the graph. @@ -565,7 +540,6 @@ def add_edge_between_nncf_nodes( :param output_port_id: Specifies the index among the possible outputs of the `from_node_id` node' that this tensor should correspond to. :param dtype: The data type of the tensor. - :param parallel_input_port_ids: Input ports for parallel edges, if any should be present for this edge. """ from_node_key = self._node_id_to_key_dict[from_node_id] to_node_key = self._node_id_to_key_dict[to_node_id] @@ -581,6 +555,19 @@ def add_edge_between_nncf_nodes( if to_node_id in self._input_nncf_nodes: err_reason = "cannot add edges *to* input nodes" + exist_edges = self._nx_graph.get_edge_data(from_node_key, to_node_key) + if exist_edges is not None: + for edge in exist_edges.values(): + if ( + edge[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR] == input_port_id + and edge[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR] == output_port_id + ): + err_reason = ( + "two edges have the same pair of port ids:" + f" input_port_id({input_port_id}) output_port_id({output_port_id})" + ) + break + if err_reason is not None: msg = f"Cannot add edge from {from_node_key} to {to_node_key} - {err_reason}!" raise ValueError(msg) @@ -590,7 +577,6 @@ def add_edge_between_nncf_nodes( NNCFGraph.INPUT_PORT_ID_EDGE_ATTR: input_port_id, NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR: output_port_id, NNCFGraph.DTYPE_EDGE_ATTR: dtype, - NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR: [] if parallel_input_port_ids is None else parallel_input_port_ids, } self._nx_graph.add_edge(from_node_key, to_node_key, **attrs) @@ -621,7 +607,7 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph :param extended: whether the graph edges should have attributes: shape of the tensor and tensor primitive type. :return: An nx.DiGraph to be used for structure analysis """ - out_graph = nx.DiGraph() + out_graph = nx.MultiDiGraph() for node_name, node in self._nx_graph.nodes.items(): attrs_node = {"id": str(node[NNCFNode.ID_NODE_ATTR]), "type": node[NNCFNode.NODE_TYPE_ATTR]} for attr in ["color", "label", "style"]: @@ -630,13 +616,10 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph out_graph.add_node(node_name, **attrs_node) - for u, v in self._nx_graph.edges: - edge = self._nx_graph.edges[u, v] + for u, v, k in self._nx_graph.edges: + edge = self._nx_graph.edges[u, v, k] attrs_edge = {} label = {} - if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]: - label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR] - if extended: if edge[NNCFGraph.DTYPE_EDGE_ATTR] is Dtype.INTEGER: attrs_edge["style"] = "dashed" @@ -735,7 +718,6 @@ def get_nncf_graph_pattern_io(self, match: list[str]) -> NNCFGraphPatternIO: output_port_id=data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR], tensor_shape=data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR], dtype=data[NNCFGraph.DTYPE_EDGE_ATTR], - parallel_input_port_ids=data[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR], ) if from_node_key in match: output_nncf_edges.append(nncf_edge) @@ -747,36 +729,39 @@ def get_nncf_graph_pattern_io(self, match: list[str]) -> NNCFGraphPatternIO: return NNCFGraphPatternIO(input_nncf_edges, output_nncf_edges) - def get_nx_edge(self, node_u: NNCFNode, node_v: NNCFNode) -> OutEdgeView: + def get_nx_edges(self, node_u: NNCFNode, node_v: NNCFNode) -> OutEdgeView: nx_node_u = self._nx_graph.nodes[self._node_id_to_key_dict[node_u.node_id]] nx_node_v = self._nx_graph.nodes[self._node_id_to_key_dict[node_v.node_id]] - return self._nx_graph.edges[nx_node_u["key"], nx_node_v["key"]] + return self._nx_graph.get_edge_data(nx_node_u["key"], nx_node_v["key"]) def get_nodes_count(self) -> int: return int(self._nx_graph.number_of_nodes()) - def get_edge(self, from_node: NNCFNode, to_node: NNCFNode) -> NNCFGraphEdge: + def get_edges(self, from_node: NNCFNode, to_node: NNCFNode) -> list[NNCFGraphEdge]: """ - Returns an NNCFGraphEdge object that corresponds to an edge connecting two given NNCFNodes in this + Returns a list of NNCFGraphEdge objects that corresponds to edges connecting two given NNCFNodes in this graph. + :param from_node: The NNCFNode in this graph that sources the edge. :param to_node: The NNCFNode in this graph that sinks the edge. - :return: The NNCFGraphEdge object representing the edge between `from_node` and `to_node`. - """ - data = self.get_nx_edge(from_node, to_node) - return NNCFGraphEdge( - from_node, - to_node, - data[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR], - data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR], - data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR], - data[NNCFGraph.DTYPE_EDGE_ATTR], - data[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR], - ) + :return: The list of NNCFGraphEdge objects representing the edges between `from_node` and `to_node`. + """ + nx_edges = self.get_nx_edges(from_node, to_node) + return [ + NNCFGraphEdge( + from_node, + to_node, + data[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR], + data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR], + data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR], + data[NNCFGraph.DTYPE_EDGE_ATTR], + ) + for data in nx_edges.values() + ] def get_all_edges(self) -> Generator[NNCFGraphEdge, None, None]: for nx_edge in self._nx_graph.in_edges: - yield self.get_edge(self.get_node_by_key(nx_edge[0]), self.get_node_by_key(nx_edge[1])) + yield from self.get_edges(self.get_node_by_key(nx_edge[0]), self.get_node_by_key(nx_edge[1])) def remove_nodes_from(self, nodes: Iterable[NNCFNode]) -> None: """ diff --git a/src/nncf/common/insertion_point_graph.py b/src/nncf/common/insertion_point_graph.py index 6048e991d7f..c4b14376108 100644 --- a/src/nncf/common/insertion_point_graph.py +++ b/src/nncf/common/insertion_point_graph.py @@ -51,9 +51,7 @@ class InsertionPointGraph(nx.DiGraph): # type: ignore """ This graph is built from the NNCFGraph representation of the model control flow graph and adds ephemeral "insertion point nodes" into the NNCF model graph representation corresponding to operator pre- and - post-hooks. Module pre-op and post-op insertion points are currently not reflected here, but they are - probably not required for quantizing activations, for which the quantizer propagation makes sense. - This "insertion point graph" representation is useful for quantizer propagation and for referencing + post-hooks. This "insertion point graph" representation is useful for quantizer propagation and for referencing the compression algorithm hooks to the model operations to which they are applied to. """ @@ -118,13 +116,11 @@ def __init__( for edge in self._base_nx_graph.edges: input_port_id = self._base_nx_graph.edges[edge][NNCFGraph.INPUT_PORT_ID_EDGE_ATTR] dtype = self._base_nx_graph.edges[edge][NNCFGraph.DTYPE_EDGE_ATTR] - parallel_input_port_ids = self._base_nx_graph.edges[edge][NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR] - from_node, to_node = edge + from_node, to_node, _ = edge attrs = { INPUT_PORT_ID: input_port_id, self.IS_INTEGER_PATH_EDGE_ATTR: dtype is Dtype.INTEGER, - NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR: parallel_input_port_ids, } self.add_edge(from_node, to_node, **attrs) @@ -148,8 +144,6 @@ def __init__( for edge in in_edges: input_port_id = self.edges[edge][INPUT_PORT_ID] input_port_id_vs_edge[input_port_id] = edge - for parallel_input_port_id in self.edges[edge][NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]: - input_port_id_vs_edge[parallel_input_port_id] = edge encountered_input_edges = set() for pre_hook_point in pre_hook_ips: @@ -235,8 +229,8 @@ def _get_default_pre_hook_ip_list(nncf_graph: NNCFGraph) -> list[PreHookInsertio pred_nodes = nncf_graph.get_previous_nodes(nncf_node) for pred_node in pred_nodes: - input_edge = nncf_graph.get_edge(pred_node, nncf_node) - input_port_ids = [input_edge.input_port_id] + input_edge.parallel_input_port_ids + input_edges = nncf_graph.get_edges(pred_node, nncf_node) + input_port_ids = [edge.input_port_id for edge in input_edges] node_name = nncf_node.node_name for input_port_id in input_port_ids: allowed_pre_hook_insertion_points.append(PreHookInsertionPoint(node_name, input_port_id)) diff --git a/src/nncf/experimental/torch/sparsify_activations/torch_backend.py b/src/nncf/experimental/torch/sparsify_activations/torch_backend.py index 52b405ed781..f2a4a8fe868 100644 --- a/src/nncf/experimental/torch/sparsify_activations/torch_backend.py +++ b/src/nncf/experimental/torch/sparsify_activations/torch_backend.py @@ -183,10 +183,10 @@ def _get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int: """ activation_ports = [] for prev_node in graph.get_previous_nodes(node): - edge = graph.get_edge(prev_node, node) - if prev_node.metatype in CONST_NOOP_METATYPES or edge.input_port_id in node.metatype.weight_port_ids: - continue - activation_ports.append(edge.input_port_id) + for edge in graph.get_edges(prev_node, node): + if prev_node.metatype in CONST_NOOP_METATYPES or edge.input_port_id in node.metatype.weight_port_ids: + continue + activation_ports.append(edge.input_port_id) if len(activation_ports) != 1: msg = f'Cannot find activation port for node "{node}".' raise nncf.InternalError(msg) diff --git a/src/nncf/openvino/graph/nncf_graph_builder.py b/src/nncf/openvino/graph/nncf_graph_builder.py index e196f7e0d8d..e3fef181a11 100644 --- a/src/nncf/openvino/graph/nncf_graph_builder.py +++ b/src/nncf/openvino/graph/nncf_graph_builder.py @@ -109,19 +109,15 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None: output_node_id = graph.get_node_by_name(out_node.get_friendly_name()).node_id nncf_dtype = GraphConverter.convert_to_nncf_dtype(out.get_element_type()) - parallel_inputs = None - if len(inputs) > 1: - parallel_inputs = [inp.get_index() for inp in inputs[1:]] - - graph.add_edge_between_nncf_nodes( - from_node_id=in_node_id, - to_node_id=output_node_id, - tensor_shape=tensor_shape, - input_port_id=inputs[0].get_index(), - output_port_id=output_port_id, - dtype=nncf_dtype, - parallel_input_port_ids=parallel_inputs, - ) + for inp in inputs: + graph.add_edge_between_nncf_nodes( + from_node_id=in_node_id, + to_node_id=output_node_id, + tensor_shape=tensor_shape, + input_port_id=inp.get_index(), + output_port_id=output_port_id, + dtype=nncf_dtype, + ) @staticmethod def _add_nncf_node(node: ov.Node, graph: NNCFGraph) -> None: diff --git a/src/nncf/quantization/algorithms/smooth_quant/algorithm.py b/src/nncf/quantization/algorithms/smooth_quant/algorithm.py index 3aa7ef9eb4e..5caea950bce 100644 --- a/src/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/src/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -226,10 +226,16 @@ def _group_nodes_by_source( groups = defaultdict(list) for node_to_smooth, input_act_port, shape in nodes_to_smooth: source_node = nncf_graph.get_input_edge_by_port_id(node_to_smooth, input_act_port).from_node - edge = nncf_graph.get_edge(source_node, node_to_smooth) + edges = nncf_graph.get_edges(source_node, node_to_smooth) + if len(edges) > 1: + msg = ( + f"Only one node expected between {source_node.node_name} and {node_to_smooth.node_name}," + f" but {len(edges)} found" + ) + raise nncf.InternalError(msg) # Such group_id (with node, ports, and shape as a hash) allows us to be confident # that all sensitive parameters are equal for successor nodes are equal. - group_id = (source_node, input_act_port, edge.output_port_id, shape) + group_id = (source_node, input_act_port, edges[0].output_port_id, shape) groups[group_id].append(node_to_smooth) return groups diff --git a/src/nncf/quantization/algorithms/weight_compression/torch_backend.py b/src/nncf/quantization/algorithms/weight_compression/torch_backend.py index eb142c032b4..bca44c529b2 100644 --- a/src/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -101,12 +101,12 @@ def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool: ): return False for prev_node in graph.get_previous_nodes(node): - edge = graph.get_edge(prev_node, node) - if edge.input_port_id not in node.metatype.weight_port_ids: - continue - weight_node = find_const_node_in_constant_subgraph(prev_node, graph) - if weight_node is not None: - return True + for edge in graph.get_edges(prev_node, node): + if edge.input_port_id not in node.metatype.weight_port_ids: + continue + weight_node = find_const_node_in_constant_subgraph(prev_node, graph) + if weight_node is not None: + return True return False @staticmethod @@ -116,9 +116,9 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> list[tupl weight_node = find_const_node_in_constant_subgraph(prev_node, graph) if weight_node is None: continue - edge = graph.get_edge(prev_node, node) - if edge.input_port_id in node.metatype.weight_port_ids: - weight_port_ids.append((weight_node.layer_attributes.name, edge.input_port_id)) + for edge in graph.get_edges(prev_node, node): + if edge.input_port_id in node.metatype.weight_port_ids: + weight_port_ids.append((weight_node.layer_attributes.name, edge.input_port_id)) return weight_port_ids @staticmethod @@ -154,8 +154,8 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int: for prev_node in graph.get_previous_nodes(node): if prev_node.metatype in CONST_NOOP_METATYPES: continue - edge = graph.get_edge(prev_node, node) - activation_ports.append(edge.input_port_id) + for edge in graph.get_edges(prev_node, node): + activation_ports.append(edge.input_port_id) assert len(activation_ports) == 1 return activation_ports[0] diff --git a/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py index 2182c85b6bf..ef2d4cabd56 100644 --- a/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py @@ -91,10 +91,18 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> list[tupl @staticmethod def get_reduction_axes(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Optional[tuple[int]]: weight_node = get_const_node(node_with_weight, weight_port_id, graph) - edge = graph.get_edge(weight_node, graph.get_next_nodes(weight_node)[0]) + next_node = graph.get_next_nodes(weight_node)[0] + edges = graph.get_edges(weight_node, next_node) + + if len(edges) > 1: + msg = ( + f"Only one node expected between the {weight_node.node_name} and {next_node.node_name}," + f" but {len(edges)} found" + ) + raise nncf.InternalError(msg) node_with_weight_metatype = node_with_weight.metatype - ndims = len(edge.tensor_shape) + ndims = len(edges[0].tensor_shape) reduction_axes = get_weight_compression_reduction_axes(node_with_weight_metatype, weight_port_id, ndims) return tuple(reduction_axes) @@ -139,8 +147,14 @@ def get_weight_dtype( @staticmethod def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> tuple: weight_node = get_const_node(node_with_weight, weight_port_id, graph) - edge = graph.get_edge(weight_node, node_with_weight) - return tuple(edge.tensor_shape) + edges = graph.get_edges(weight_node, node_with_weight) + if len(edges) > 1: + msg = ( + f"Only one node expected between the {node_with_weight.node_name} and its weight," + f" but {len(edges)} found" + ) + raise nncf.InternalError(msg) + return tuple(edges[0].tensor_shape) def set_weight( self, diff --git a/src/nncf/quantization/passes.py b/src/nncf/quantization/passes.py index 0f1f8d682f6..601d6e9cab8 100644 --- a/src/nncf/quantization/passes.py +++ b/src/nncf/quantization/passes.py @@ -146,20 +146,18 @@ def remove_nodes_and_reconnect_graph( nodes_to_drop = [] for node in nncf_graph.get_nodes_by_metatypes(metatypes): - if node.metatype in metatypes: - nodes_to_drop.append(node) - - prev_nodes = nncf_graph.get_previous_nodes(node) - input_edges = nncf_graph.get_input_edges(node) - assert len(prev_nodes) == len(input_edges) == 1 - prev_node = prev_nodes[0] - input_edge = input_edges[0] - assert not input_edge.parallel_input_port_ids - - # nncf_graph.get_next_edges is not used to preserve - # parallel_input_port_ids - for output_node in nncf_graph.get_next_nodes(node): - output_edge = nncf_graph.get_edge(node, output_node) + if node.metatype not in metatypes: + continue + nodes_to_drop.append(node) + + prev_nodes = nncf_graph.get_previous_nodes(node) + input_edges = nncf_graph.get_input_edges(node) + assert len(prev_nodes) == len(input_edges) == 1 + prev_node = prev_nodes[0] + input_edge = input_edges[0] + + for output_node in nncf_graph.get_next_nodes(node): + for output_edge in nncf_graph.get_edges(node, output_node): # Connects previous node with all next nodes # to keep NNCFGraph connected. assert input_edge.dtype == output_edge.dtype @@ -171,7 +169,6 @@ def remove_nodes_and_reconnect_graph( input_port_id=output_edge.input_port_id, output_port_id=input_edge.output_port_id, dtype=input_edge.dtype, - parallel_input_port_ids=output_edge.parallel_input_port_ids, ) nncf_graph.remove_nodes_from(nodes_to_drop) return nncf_graph diff --git a/src/nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py b/src/nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py index 0080ccb0564..c7f31c99769 100644 --- a/src/nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py +++ b/src/nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py @@ -201,15 +201,15 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> PTNNCFGraph: for (s_node, t_node), list_meta in map_edges.items(): source_node = map_nx_node_to_nncf_node[s_node] target_node = map_nx_node_to_nncf_node[t_node] - nncf_graph.add_edge_between_nncf_nodes( - source_node.node_id, - target_node.node_id, - tensor_shape=list_meta[0].shape, - input_port_id=list_meta[0].input_port, - output_port_id=list_meta[0].output_port, - dtype=get_dtype(list_meta[0].dtype), - parallel_input_port_ids=[meta.input_port for meta in list_meta[1:]] if len(list_meta) > 1 else None, - ) + for meta in list_meta: + nncf_graph.add_edge_between_nncf_nodes( + source_node.node_id, + target_node.node_id, + tensor_shape=meta.shape, + input_port_id=meta.input_port, + output_port_id=meta.output_port, + dtype=get_dtype(meta.dtype), + ) return nncf_graph diff --git a/src/nncf/torch/graph/graph.py b/src/nncf/torch/graph/graph.py index 19c3053ba81..548f323be03 100644 --- a/src/nncf/torch/graph/graph.py +++ b/src/nncf/torch/graph/graph.py @@ -13,42 +13,19 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.graph import NNCFNodeName from nncf.torch.function_hook.graph.graph_utils import TensorMeta from nncf.torch.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes from nncf.torch.graph.transformations.commands import PTTargetPoint class PTNNCFGraph(NNCFGraph): - def get_output_shapes_for_node(self, node_name: NNCFNodeName) -> list[tuple]: - node = self.get_node_by_name(node_name) - node_key = self.get_node_key_by_id(node.node_id) - succs = list(self._nx_graph.successors(node_key)) - edge_list = [self._nx_graph.edges[node_key, to_node_key] for to_node_key in succs] - return [edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR] for edge in edge_list] - - def get_input_shapes_for_node(self, node_name: NNCFNodeName) -> dict[int, tuple]: - node = self.get_node_by_name(node_name) - node_key = self.get_node_key_by_id(node.node_id) - in_edges = list(self._nx_graph.in_edges(node_key)) - retval = {} - for in_edge in in_edges: - edge_attr_dict = self._nx_graph.edges[in_edge] - port_id = edge_attr_dict[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR] - assert port_id not in retval - for p in [ - port_id, - ] + edge_attr_dict[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]: - retval[p] = edge_attr_dict[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR] - return retval - def get_input_shape_for_insertion_point(self, insertion_point: PTTargetPoint) -> tuple[int]: - target_node_name = insertion_point.target_node_name + node = self.get_node_by_name(insertion_point.target_node_name) if insertion_point.input_port_id is not None: - quantizer_input_shape = self.get_input_shapes_for_node(target_node_name)[insertion_point.input_port_id] + quantizer_input_shape = [edge.tensor_shape for edge in self.get_input_edges(node)] else: # Tailored for post-hook quantization and first output quantization only - quantizer_input_shape = self.get_output_shapes_for_node(target_node_name)[0] + quantizer_input_shape = self.get_output_edges(node)[0].tensor_shape return quantizer_input_shape def get_nodes_with_missed_input_edges(self) -> list[NNCFNode]: diff --git a/src/nncf/torch/model_graph_manager.py b/src/nncf/torch/model_graph_manager.py index 53413731199..3f193dd40cb 100644 --- a/src/nncf/torch/model_graph_manager.py +++ b/src/nncf/torch/model_graph_manager.py @@ -68,13 +68,13 @@ def get_const_node(node: NNCFNode, port_id: int, graph: NNCFGraph) -> Optional[N :return: The NNCF node providing the constant input to the specified port, or None if no such node is found. """ for prev_node in graph.get_previous_nodes(node): - edge = graph.get_edge(prev_node, node) - if edge.input_port_id == port_id: - weight_node = find_const_node_in_constant_subgraph(prev_node, graph) - if weight_node is None: - msg = "Could not find a constant node in the model graph." - raise nncf.InternalError(msg) - return weight_node + for edge in graph.get_edges(prev_node, node): + if edge.input_port_id == port_id: + weight_node = find_const_node_in_constant_subgraph(prev_node, graph) + if weight_node is None: + msg = "Could not find a constant node in the model graph." + raise nncf.InternalError(msg) + return weight_node def split_const_name(const_name: str) -> tuple[str, str]: diff --git a/tests/common/graph/test_nncf_graph.py b/tests/common/graph/test_nncf_graph.py index c3366e04bf3..c7f65529240 100644 --- a/tests/common/graph/test_nncf_graph.py +++ b/tests/common/graph/test_nncf_graph.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFGraphEdge from nncf.common.graph.layer_attributes import Dtype @@ -51,38 +53,37 @@ def test_find_matching_subgraphs(): assert match == nodes[:2] -def test_parallel_edges(): - def _get_default_nncf_graph_edge(from_node, to_node, input_port_id, output_port_id): - return NNCFGraphEdge( - from_node, - to_node, - input_port_id=input_port_id, - output_port_id=output_port_id, - parallel_input_port_ids=[], - tensor_shape=(1, 2, 3), - dtype="dummy", - ) +def _get_default_nncf_graph_edge(from_node, to_node, input_port_id, output_port_id): + return NNCFGraphEdge( + from_node, + to_node, + input_port_id=input_port_id, + output_port_id=output_port_id, + tensor_shape=(1, 2, 3), + dtype="dummy", + ) + +def test_parallel_edges(): nncf_graph = NNCFGraph() nodes = [] for node in "abc": nodes.append(nncf_graph.add_nncf_node(node, f"type_{node}", f"metatype_{node}")) - nncf_graph.add_edge_between_nncf_nodes( - nodes[0].node_id, - nodes[1].node_id, - input_port_id=0, - output_port_id=0, - parallel_input_port_ids=list(range(1, 5)), - tensor_shape=(1, 2, 3), - dtype="dummy", - ) + for input_port_id in range(5): + nncf_graph.add_edge_between_nncf_nodes( + nodes[0].node_id, + nodes[1].node_id, + input_port_id=input_port_id, + output_port_id=0, + tensor_shape=(1, 2, 3), + dtype="dummy", + ) nncf_graph.add_edge_between_nncf_nodes( nodes[0].node_id, nodes[2].node_id, input_port_id=10, output_port_id=15, - parallel_input_port_ids=[], tensor_shape=(1, 2, 3), dtype="dummy", ) @@ -107,3 +108,68 @@ def _get_default_nncf_graph_edge(from_node, to_node, input_port_id, output_port_ output_port_id=15, ) assert ordinary_edge == output_edges[-1] + + +def test_raise_error_for_dublicated_edge(): + nncf_graph = NNCFGraph() + nodes = [] + for node in "abc": + nodes.append(nncf_graph.add_nncf_node(node, f"type_{node}", f"metatype_{node}")) + + nncf_graph.add_edge_between_nncf_nodes( + nodes[0].node_id, + nodes[1].node_id, + input_port_id=0, + output_port_id=0, + tensor_shape=(1, 2, 3), + dtype="dummy", + ) + # Second edge from port 1 to port 1 - OK + nncf_graph.add_edge_between_nncf_nodes( + nodes[0].node_id, + nodes[1].node_id, + input_port_id=1, + output_port_id=1, + tensor_shape=(1, 2, 3), + dtype="dummy", + ) + with pytest.raises(ValueError): + nncf_graph.add_edge_between_nncf_nodes( + nodes[0].node_id, + nodes[1].node_id, + input_port_id=1, + output_port_id=1, + tensor_shape=(1, 2, 3), + dtype="dummy", + ) + + +def test_multi_edges(): + nncf_graph = NNCFGraph() + nodes = [] + for node in "ab": + nodes.append(nncf_graph.add_nncf_node(node, f"type_{node}", f"metatype_{node}")) + + for port_id in range(5): + nncf_graph.add_edge_between_nncf_nodes( + nodes[0].node_id, + nodes[1].node_id, + input_port_id=port_id, + output_port_id=port_id, + tensor_shape=(1, 2, 3), + dtype="dummy", + ) + + output_edges = nncf_graph.get_output_edges(nodes[0]) + input_edges = nncf_graph.get_input_edges(nodes[1]) + assert len(input_edges) == 5 + assert len(output_edges) == 5 + assert input_edges == output_edges + for port_id, edge in enumerate(input_edges): + ref_edge = _get_default_nncf_graph_edge( + nodes[0], + nodes[1], + input_port_id=port_id, + output_port_id=port_id, + ) + assert ref_edge == edge diff --git a/tests/cross_fw/shared/nx_graph.py b/tests/cross_fw/shared/nx_graph.py index db6f9fd271f..b855468d97f 100644 --- a/tests/cross_fw/shared/nx_graph.py +++ b/tests/cross_fw/shared/nx_graph.py @@ -120,7 +120,7 @@ def _build_edge_vs_attrs_dict( ) -> dict[tuple[Union[int, str], Union[int, str]], dict[str, str]]: retval = {} for edge_tuple, edge_attrs in nx_graph.edges.items(): - from_node_name, to_node_name = edge_tuple + from_node_name, to_node_name, _ = edge_tuple if id_from_attr: from_node, to_node = nx_graph.nodes[from_node_name], nx_graph.nodes[to_node_name] edge_id = int(from_node["id"]), int(to_node["id"]) diff --git a/tests/cross_fw/test_templates/models.py b/tests/cross_fw/test_templates/models.py index bb1e2978006..099228a9348 100644 --- a/tests/cross_fw/test_templates/models.py +++ b/tests/cross_fw/test_templates/models.py @@ -319,26 +319,26 @@ def __init__( dropout_2 = self.nncf_graph.get_node_by_key("4 /Dropout_3_0") output = self.nncf_graph.add_nncf_node("/Output_3_1_0", "output", OutputNoopMetatype) - self.nncf_graph.add_edge_between_nncf_nodes( - dropout_2.node_id, - output.node_id, - tensor_shape=tensor_shape, - input_port_id=1, - output_port_id=1, - dtype=Dtype.FLOAT, - parallel_input_port_ids=list(range(2, 10)), - ) - if wrong_parallel_edges: - dropout_4 = self.nncf_graph.add_nncf_node("100 /dropout", "dropout", dropout_metatype) + for input_port_id in range(1, 10): self.nncf_graph.add_edge_between_nncf_nodes( - self.nncf_graph.get_node_by_key("0 /Input_1_0").node_id, - dropout_4.node_id, - tensor_shape=[1, 1, 1, 1], - input_port_id=0, - output_port_id=0, + dropout_2.node_id, + output.node_id, + tensor_shape=tensor_shape, + input_port_id=input_port_id, + output_port_id=1, dtype=Dtype.FLOAT, - parallel_input_port_ids=list(range(1, 10)), ) + if wrong_parallel_edges: + dropout_4 = self.nncf_graph.add_nncf_node("100 /dropout", "dropout", dropout_metatype) + for input_port_id in range(10): + self.nncf_graph.add_edge_between_nncf_nodes( + self.nncf_graph.get_node_by_key("0 /Input_1_0").node_id, + dropout_4.node_id, + tensor_shape=[1, 1, 1, 1], + input_port_id=input_port_id, + output_port_id=0, + dtype=Dtype.FLOAT, + ) class NNCFGraphToTestConstantFiltering: diff --git a/tests/openvino/native/test_nncf_graph_builder.py b/tests/openvino/native/test_nncf_graph_builder.py index cc42f6e39d0..cbc25900d83 100644 --- a/tests/openvino/native/test_nncf_graph_builder.py +++ b/tests/openvino/native/test_nncf_graph_builder.py @@ -75,7 +75,6 @@ def _get_default_nncf_graph_edge(from_node, to_node, input_port_id, output_port_ output_port_id=output_port_id, tensor_shape=[1, 3, 3], dtype=Dtype.FLOAT, - parallel_input_port_ids=[], ) model = ParallelEdgesModel().ov_model diff --git a/tests/torch/fx/test_quantizer.py b/tests/torch/fx/test_quantizer.py index 90734f60b0f..9c58b3dd153 100644 --- a/tests/torch/fx/test_quantizer.py +++ b/tests/torch/fx/test_quantizer.py @@ -313,7 +313,6 @@ def _normalize_nncf_graph(nncf_graph: NNCFGraph, fx_graph: torch.fx.Graph): input_port_id=edge.input_port_id, output_port_id=edge.output_port_id, dtype=dtype, - parallel_input_port_ids=edge.parallel_input_port_ids, ) return norm_nncf_graph diff --git a/tests/torch/fx/test_weights_compression.py b/tests/torch/fx/test_weights_compression.py index ac5b96756b0..51cce777c41 100644 --- a/tests/torch/fx/test_weights_compression.py +++ b/tests/torch/fx/test_weights_compression.py @@ -125,7 +125,7 @@ def test_compress_weights_graph_edge(mode): for node in nncf_graph.get_all_nodes(): if "weights_decompressor" in node.node_name and node.node_type == "call_module": decompressor_node_edge = nncf_graph.get_input_edges(node)[0] - decompressor_constant_edge = nncf_graph.get_edge(node, nncf_graph.get_next_nodes(node)[0]) + decompressor_constant_edge = nncf_graph.get_edges(node, nncf_graph.get_next_nodes(node)[0])[0] assert decompressor_node_edge.tensor_shape == decompressor_constant_edge.tensor_shape diff --git a/tests/torch/test_graph_analysis.py b/tests/torch/test_graph_analysis.py index 3e495081075..39a300518a5 100644 --- a/tests/torch/test_graph_analysis.py +++ b/tests/torch/test_graph_analysis.py @@ -57,7 +57,6 @@ def make_mock_edge( output_port_id=output_port_id, tensor_shape=[1, 1, 1, 1], dtype=Dtype.FLOAT, - parallel_input_port_ids=[], ) def get_node(name: NNCFNodeName): diff --git a/tests/torch/utils.py b/tests/torch/utils.py index 8f8f9b89959..270a8471839 100644 --- a/tests/torch/utils.py +++ b/tests/torch/utils.py @@ -67,7 +67,6 @@ def to_comparable_nx_graph(graph: NNCFGraph) -> nx.DiGraph: - shape - out_port_id - in_port_id - - parallel_input_port_ids (if exists) :param graph: NNCFGraph to convert. :return: Graph in nx.DiGraph. @@ -88,8 +87,6 @@ def to_comparable_nx_graph(graph: NNCFGraph) -> nx.DiGraph: "out_port_id": edge.output_port_id, "in_port_id": edge.input_port_id, } - if edge.parallel_input_port_ids: - attrs_edge["parallel_input_port_ids"] = edge.parallel_input_port_ids out_graph.add_edge(_quote_str(edge.from_node.node_name), _quote_str(edge.to_node.node_name), **attrs_edge) return out_graph