Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 39 additions & 54 deletions src/nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def __init__(
output_port_id: int,
tensor_shape: list[int],
dtype: Dtype,
parallel_input_port_ids: list[int],
) -> None:
"""
:param from_node: An NNCFNode that sources the directed edge.
Expand All @@ -159,7 +158,6 @@ def __init__(
self.output_port_id = output_port_id
self.tensor_shape: tuple[int, ...] = tuple(tensor_shape)
self.dtype = dtype
self.parallel_input_port_ids = parallel_input_port_ids

def __str__(self) -> str:
return f"{self.from_node}:{self.output_port_id} -> {self.tensor_shape} -> {self.to_node}:{self.input_port_id}"
Expand All @@ -173,7 +171,6 @@ def __hash__(self) -> int:
self.output_port_id,
tuple(self.tensor_shape),
self.dtype,
tuple(self.parallel_input_port_ids),
)
)

Expand Down Expand Up @@ -201,10 +198,9 @@ class NNCFGraph:
INPUT_PORT_ID_EDGE_ATTR = "input_port_id"
OUTPUT_PORT_ID_EDGE_ATTR = "output_port_id"
DTYPE_EDGE_ATTR = "dtype"
PARALLEL_INPUT_PORT_IDS_ATTR = "parallel_input_ports"

def __init__(self) -> None:
self._nx_graph = nx.DiGraph()
self._nx_graph = nx.MultiDiGraph()
self._node_id_to_key_dict: dict[int, str] = {}
self._nodes: dict[str, NNCFNode] = {}
self._input_nncf_nodes: dict[int, NNCFNode] = {}
Expand Down Expand Up @@ -353,7 +349,7 @@ def get_input_edges(self, node: NNCFNode) -> list[NNCFGraphEdge]:
input_nodes = self.get_previous_nodes(node)
edges = []
for from_node in input_nodes:
edges.extend(self._get_edges(from_node, node))
edges.extend(self.get_edges(from_node, node))
return sorted(edges, key=lambda x: x.input_port_id)

def get_input_edge_by_port_id(self, node: NNCFNode, port_id: int) -> NNCFGraphEdge:
Expand Down Expand Up @@ -385,7 +381,7 @@ def get_output_edges(self, node: NNCFNode) -> list[NNCFGraphEdge]:
output_nodes = self.get_next_nodes(node)
edges = []
for to_node in output_nodes:
edges.extend(self._get_edges(node, to_node))
edges.extend(self.get_edges(node, to_node))
return sorted(edges, key=lambda x: x.output_port_id)

def get_output_edges_by_port_id(self, node: NNCFNode, port_id: int) -> list[NNCFGraphEdge]:
Expand All @@ -400,26 +396,6 @@ def get_output_edges_by_port_id(self, node: NNCFNode, port_id: int) -> list[NNCF
"""
return [e for e in self.get_output_edges(node) if e.output_port_id == port_id]

def _get_edges(self, from_node: NNCFNode, to_node: NNCFNode) -> list[NNCFGraphEdge]:
edges = []
edge = self.get_edge(from_node, to_node)
parallel_input_port_ids = edge.parallel_input_port_ids
edge.parallel_input_port_ids = []
edges.append(edge)
for input_port_id in parallel_input_port_ids:
edges.append(
NNCFGraphEdge(
from_node=edge.from_node,
to_node=edge.to_node,
input_port_id=input_port_id,
output_port_id=edge.output_port_id,
tensor_shape=list(edge.tensor_shape),
dtype=edge.dtype,
parallel_input_port_ids=[],
)
)
return edges

def traverse_graph(
self,
curr_node: NNCFNode,
Expand Down Expand Up @@ -552,7 +528,6 @@ def add_edge_between_nncf_nodes(
input_port_id: int,
output_port_id: int,
dtype: Dtype,
parallel_input_port_ids: Optional[list[int]] = None,
) -> None:
"""
Adds a directed edge between two `NNCFNode`s that are already present in the graph.
Expand All @@ -565,7 +540,6 @@ def add_edge_between_nncf_nodes(
:param output_port_id: Specifies the index among the possible outputs of the `from_node_id` node' that this
tensor should correspond to.
:param dtype: The data type of the tensor.
:param parallel_input_port_ids: Input ports for parallel edges, if any should be present for this edge.
"""
from_node_key = self._node_id_to_key_dict[from_node_id]
to_node_key = self._node_id_to_key_dict[to_node_id]
Expand All @@ -581,6 +555,19 @@ def add_edge_between_nncf_nodes(
if to_node_id in self._input_nncf_nodes:
err_reason = "cannot add edges *to* input nodes"

exist_edges = self._nx_graph.get_edge_data(from_node_key, to_node_key)
if exist_edges is not None:
for edge in exist_edges.values():
if (
edge[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR] == input_port_id
and edge[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR] == output_port_id
):
err_reason = (
"two edges have the same pair of port ids:"
f" input_port_id({input_port_id}) output_port_id({output_port_id})"
)
break

if err_reason is not None:
msg = f"Cannot add edge from {from_node_key} to {to_node_key} - {err_reason}!"
raise ValueError(msg)
Expand All @@ -590,7 +577,6 @@ def add_edge_between_nncf_nodes(
NNCFGraph.INPUT_PORT_ID_EDGE_ATTR: input_port_id,
NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR: output_port_id,
NNCFGraph.DTYPE_EDGE_ATTR: dtype,
NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR: [] if parallel_input_port_ids is None else parallel_input_port_ids,
}
self._nx_graph.add_edge(from_node_key, to_node_key, **attrs)

Expand Down Expand Up @@ -621,7 +607,7 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
:param extended: whether the graph edges should have attributes: shape of the tensor and tensor primitive type.
:return: An nx.DiGraph to be used for structure analysis
"""
out_graph = nx.DiGraph()
out_graph = nx.MultiDiGraph()
for node_name, node in self._nx_graph.nodes.items():
attrs_node = {"id": str(node[NNCFNode.ID_NODE_ATTR]), "type": node[NNCFNode.NODE_TYPE_ATTR]}
for attr in ["color", "label", "style"]:
Expand All @@ -630,13 +616,10 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph

out_graph.add_node(node_name, **attrs_node)

for u, v in self._nx_graph.edges:
edge = self._nx_graph.edges[u, v]
for u, v, k in self._nx_graph.edges:
edge = self._nx_graph.edges[u, v, k]
attrs_edge = {}
label = {}
if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]

if extended:
if edge[NNCFGraph.DTYPE_EDGE_ATTR] is Dtype.INTEGER:
attrs_edge["style"] = "dashed"
Expand Down Expand Up @@ -735,7 +718,6 @@ def get_nncf_graph_pattern_io(self, match: list[str]) -> NNCFGraphPatternIO:
output_port_id=data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR],
tensor_shape=data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR],
dtype=data[NNCFGraph.DTYPE_EDGE_ATTR],
parallel_input_port_ids=data[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR],
)
if from_node_key in match:
output_nncf_edges.append(nncf_edge)
Expand All @@ -747,36 +729,39 @@ def get_nncf_graph_pattern_io(self, match: list[str]) -> NNCFGraphPatternIO:

return NNCFGraphPatternIO(input_nncf_edges, output_nncf_edges)

def get_nx_edge(self, node_u: NNCFNode, node_v: NNCFNode) -> OutEdgeView:
def get_nx_edges(self, node_u: NNCFNode, node_v: NNCFNode) -> OutEdgeView:
nx_node_u = self._nx_graph.nodes[self._node_id_to_key_dict[node_u.node_id]]
nx_node_v = self._nx_graph.nodes[self._node_id_to_key_dict[node_v.node_id]]
return self._nx_graph.edges[nx_node_u["key"], nx_node_v["key"]]
return self._nx_graph.get_edge_data(nx_node_u["key"], nx_node_v["key"])

def get_nodes_count(self) -> int:
return int(self._nx_graph.number_of_nodes())

def get_edge(self, from_node: NNCFNode, to_node: NNCFNode) -> NNCFGraphEdge:
def get_edges(self, from_node: NNCFNode, to_node: NNCFNode) -> list[NNCFGraphEdge]:
"""
Returns an NNCFGraphEdge object that corresponds to an edge connecting two given NNCFNodes in this
Returns a list of NNCFGraphEdge objects that corresponds to edges connecting two given NNCFNodes in this
graph.

:param from_node: The NNCFNode in this graph that sources the edge.
:param to_node: The NNCFNode in this graph that sinks the edge.
:return: The NNCFGraphEdge object representing the edge between `from_node` and `to_node`.
"""
data = self.get_nx_edge(from_node, to_node)
return NNCFGraphEdge(
from_node,
to_node,
data[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR],
data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR],
data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR],
data[NNCFGraph.DTYPE_EDGE_ATTR],
data[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR],
)
:return: The list of NNCFGraphEdge objects representing the edges between `from_node` and `to_node`.
"""
nx_edges = self.get_nx_edges(from_node, to_node)
return [
NNCFGraphEdge(
from_node,
to_node,
data[NNCFGraph.INPUT_PORT_ID_EDGE_ATTR],
data[NNCFGraph.OUTPUT_PORT_ID_EDGE_ATTR],
data[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR],
data[NNCFGraph.DTYPE_EDGE_ATTR],
)
for data in nx_edges.values()
]

def get_all_edges(self) -> Generator[NNCFGraphEdge, None, None]:
for nx_edge in self._nx_graph.in_edges:
yield self.get_edge(self.get_node_by_key(nx_edge[0]), self.get_node_by_key(nx_edge[1]))
yield from self.get_edges(self.get_node_by_key(nx_edge[0]), self.get_node_by_key(nx_edge[1]))

def remove_nodes_from(self, nodes: Iterable[NNCFNode]) -> None:
"""
Expand Down
14 changes: 4 additions & 10 deletions src/nncf/common/insertion_point_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ class InsertionPointGraph(nx.DiGraph): # type: ignore
"""
This graph is built from the NNCFGraph representation of the model control flow graph and adds ephemeral
"insertion point nodes" into the NNCF model graph representation corresponding to operator pre- and
post-hooks. Module pre-op and post-op insertion points are currently not reflected here, but they are
probably not required for quantizing activations, for which the quantizer propagation makes sense.
This "insertion point graph" representation is useful for quantizer propagation and for referencing
post-hooks. This "insertion point graph" representation is useful for quantizer propagation and for referencing
the compression algorithm hooks to the model operations to which they are applied to.
"""

Expand Down Expand Up @@ -118,13 +116,11 @@ def __init__(
for edge in self._base_nx_graph.edges:
input_port_id = self._base_nx_graph.edges[edge][NNCFGraph.INPUT_PORT_ID_EDGE_ATTR]
dtype = self._base_nx_graph.edges[edge][NNCFGraph.DTYPE_EDGE_ATTR]
parallel_input_port_ids = self._base_nx_graph.edges[edge][NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]
from_node, to_node = edge
from_node, to_node, _ = edge

attrs = {
INPUT_PORT_ID: input_port_id,
self.IS_INTEGER_PATH_EDGE_ATTR: dtype is Dtype.INTEGER,
NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR: parallel_input_port_ids,
}
self.add_edge(from_node, to_node, **attrs)

Expand All @@ -148,8 +144,6 @@ def __init__(
for edge in in_edges:
input_port_id = self.edges[edge][INPUT_PORT_ID]
input_port_id_vs_edge[input_port_id] = edge
for parallel_input_port_id in self.edges[edge][NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
input_port_id_vs_edge[parallel_input_port_id] = edge

encountered_input_edges = set()
for pre_hook_point in pre_hook_ips:
Expand Down Expand Up @@ -235,8 +229,8 @@ def _get_default_pre_hook_ip_list(nncf_graph: NNCFGraph) -> list[PreHookInsertio
pred_nodes = nncf_graph.get_previous_nodes(nncf_node)

for pred_node in pred_nodes:
input_edge = nncf_graph.get_edge(pred_node, nncf_node)
input_port_ids = [input_edge.input_port_id] + input_edge.parallel_input_port_ids
input_edges = nncf_graph.get_edges(pred_node, nncf_node)
input_port_ids = [edge.input_port_id for edge in input_edges]
node_name = nncf_node.node_name
for input_port_id in input_port_ids:
allowed_pre_hook_insertion_points.append(PreHookInsertionPoint(node_name, input_port_id))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ def _get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
"""
activation_ports = []
for prev_node in graph.get_previous_nodes(node):
edge = graph.get_edge(prev_node, node)
if prev_node.metatype in CONST_NOOP_METATYPES or edge.input_port_id in node.metatype.weight_port_ids:
continue
activation_ports.append(edge.input_port_id)
for edge in graph.get_edges(prev_node, node):
if prev_node.metatype in CONST_NOOP_METATYPES or edge.input_port_id in node.metatype.weight_port_ids:
continue
activation_ports.append(edge.input_port_id)
if len(activation_ports) != 1:
msg = f'Cannot find activation port for node "{node}".'
raise nncf.InternalError(msg)
Expand Down
22 changes: 9 additions & 13 deletions src/nncf/openvino/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,15 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
output_node_id = graph.get_node_by_name(out_node.get_friendly_name()).node_id
nncf_dtype = GraphConverter.convert_to_nncf_dtype(out.get_element_type())

parallel_inputs = None
if len(inputs) > 1:
parallel_inputs = [inp.get_index() for inp in inputs[1:]]

graph.add_edge_between_nncf_nodes(
from_node_id=in_node_id,
to_node_id=output_node_id,
tensor_shape=tensor_shape,
input_port_id=inputs[0].get_index(),
output_port_id=output_port_id,
dtype=nncf_dtype,
parallel_input_port_ids=parallel_inputs,
)
for inp in inputs:
graph.add_edge_between_nncf_nodes(
from_node_id=in_node_id,
to_node_id=output_node_id,
tensor_shape=tensor_shape,
input_port_id=inp.get_index(),
output_port_id=output_port_id,
dtype=nncf_dtype,
)

@staticmethod
def _add_nncf_node(node: ov.Node, graph: NNCFGraph) -> None:
Expand Down
10 changes: 8 additions & 2 deletions src/nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,16 @@ def _group_nodes_by_source(
groups = defaultdict(list)
for node_to_smooth, input_act_port, shape in nodes_to_smooth:
source_node = nncf_graph.get_input_edge_by_port_id(node_to_smooth, input_act_port).from_node
edge = nncf_graph.get_edge(source_node, node_to_smooth)
edges = nncf_graph.get_edges(source_node, node_to_smooth)
if len(edges) > 1:
msg = (
f"Only one node expected between {source_node.node_name} and {node_to_smooth.node_name},"
f" but {len(edges)} found"
)
raise nncf.InternalError(msg)
# Such group_id (with node, ports, and shape as a hash) allows us to be confident
# that all sensitive parameters are equal for successor nodes are equal.
group_id = (source_node, input_act_port, edge.output_port_id, shape)
group_id = (source_node, input_act_port, edges[0].output_port_id, shape)
groups[group_id].append(node_to_smooth)

return groups
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool:
):
return False
for prev_node in graph.get_previous_nodes(node):
edge = graph.get_edge(prev_node, node)
if edge.input_port_id not in node.metatype.weight_port_ids:
continue
weight_node = find_const_node_in_constant_subgraph(prev_node, graph)
if weight_node is not None:
return True
for edge in graph.get_edges(prev_node, node):
if edge.input_port_id not in node.metatype.weight_port_ids:
continue
weight_node = find_const_node_in_constant_subgraph(prev_node, graph)
if weight_node is not None:
return True
return False

@staticmethod
Expand All @@ -116,9 +116,9 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> list[tupl
weight_node = find_const_node_in_constant_subgraph(prev_node, graph)
if weight_node is None:
continue
edge = graph.get_edge(prev_node, node)
if edge.input_port_id in node.metatype.weight_port_ids:
weight_port_ids.append((weight_node.layer_attributes.name, edge.input_port_id))
for edge in graph.get_edges(prev_node, node):
if edge.input_port_id in node.metatype.weight_port_ids:
weight_port_ids.append((weight_node.layer_attributes.name, edge.input_port_id))
return weight_port_ids

@staticmethod
Expand Down Expand Up @@ -154,8 +154,8 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
for prev_node in graph.get_previous_nodes(node):
if prev_node.metatype in CONST_NOOP_METATYPES:
continue
edge = graph.get_edge(prev_node, node)
activation_ports.append(edge.input_port_id)
for edge in graph.get_edges(prev_node, node):
activation_ports.append(edge.input_port_id)
assert len(activation_ports) == 1
return activation_ports[0]

Expand Down
Loading
Loading