-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathabstraction_graph_toy_example.py
66 lines (57 loc) · 2.5 KB
/
abstraction_graph_toy_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Create an abstraction graph for the toy example."""
import numpy as np
from treelib import Tree
def make_abstraction_graph(misaligned=False):
    """Build the toy two-level abstraction graph.

    The graph contains a ``root`` node, two superclass nodes (``parent0`` and
    ``parent1``) under it, and three class leaves (``child0``, ``child1``,
    ``child2``).  ``child0`` always hangs under ``parent0`` and ``child2``
    under ``parent1``; ``child1`` hangs under ``parent1`` by default and under
    ``parent0`` when ``misaligned`` is truthy.  Every node uses the same
    string for its tag and identifier, and its ``data`` starts as ``None``.

    Args:
        misaligned: If truthy, attach ``child1`` to ``parent0`` instead of
            ``parent1``.

    Returns:
        A new treelib ``Tree`` with the layout described above.
    """
    child1_parent = 'parent0' if misaligned else 'parent1'

    # (node name, parent name) pairs, listed in creation order; a parent must
    # be created before any of its children.
    layout = (
        ('root', None),
        # superclass nodes
        ('parent0', 'root'),
        ('parent1', 'root'),
        # class nodes
        ('child0', 'parent0'),
        ('child1', child1_parent),
        ('child2', 'parent1'),
    )

    graph = Tree()
    for name, parent_name in layout:
        graph.create_node(tag=name, identifier=name, parent=parent_name, data=None)
    return graph
def show_abstraction_graph(abstraction_graph, hide_zeros=False):
    """Render the abstraction graph as text with each node's value appended.

    The tree is drawn with treelib's ``Tree.show``. Every node whose ``data``
    is not ``None`` gets its value appended in parentheses, rounded to two
    decimals, e.g. ``child0 (0.25)``. Nodes without data, for example before
    ``propagate`` has run, are printed unannotated. Formatting ``None`` with
    ``.2f`` would otherwise raise.

    Args:
        abstraction_graph: A treelib ``Tree`` whose node ``data`` values are
            numbers or ``None``.
        hide_zeros: If True, values that round to 0.00 are rendered as empty
            parentheses ``()`` instead of ``(0.00)``.

    Returns:
        The rendered tree as a string.
    """
    # Tree.show prints node *tags*, because identifiers are hidden by default.
    # The annotations are therefore keyed by tag rather than by identifier.
    annotations = {}
    for node in abstraction_graph.nodes.values():
        if node.data is None:
            continue
        value = round(node.data, 2)
        annotations[node.tag] = '' if hide_zeros and value == 0 else f'{value:.2f}'

    lines = abstraction_graph.show(stdout=False).split('\n')
    for index, line in enumerate(lines):
        # Strip the box-drawing prefix of treelib's default 'ascii-ex' style so
        # the label is compared exactly. A substring replace would also hit a
        # label that merely ends with another node's tag (e.g. 'ba' vs 'a').
        label = line.lstrip('│├└─ ')
        if line and label in annotations:
            lines[index] = f'{line} ({annotations[label]})'
    return '\n'.join(lines)
def propagate(outputs, abstraction_graph):
    """Seed the leaves with model outputs and aggregate them up the graph.

    ``outputs[i]`` is written to the ``data`` of the node whose identifier is
    ``child{i}``. The levels are then visited bottom-up, from the level just
    above the deepest one down to the root at level 0. Each visited node's
    ``data`` is replaced by ``np.sum`` of the ``data`` of every leaf reachable
    from it. The graph is modified in place and also returned.

    Args:
        outputs: An iterable of per-class values. Item ``i`` goes to node
            ``child{i}``.
        abstraction_graph: A treelib ``Tree`` containing ``child0``,
            ``child1``, ... for every item in ``outputs``.

    Returns:
        The same ``abstraction_graph`` object, now holding propagated values.
    """
    # Step 1: write each model output onto its class leaf.
    for index, output in enumerate(outputs):
        leaf_node = abstraction_graph.get_node(f'child{index}')
        leaf_node.data = output

    # Step 2: nodes on the deepest level (== depth()) keep their seeded values.
    # Every shallower level is recomputed from the leaves beneath it, moving
    # upward and ending at the root.
    deepest_level = abstraction_graph.depth()
    for current_level in range(deepest_level - 1, -1, -1):
        def on_current_level(candidate, level=current_level):
            return abstraction_graph.depth(candidate) == level

        for node in abstraction_graph.filter_nodes(on_current_level):
            leaf_values = [leaf.data for leaf in abstraction_graph.leaves(node.identifier)]
            node.data = np.sum(leaf_values)

    return abstraction_graph