Skip to content

Commit 6cc9e99

Browse files
authored
Compression: automatically propagate constants and remove unused nodes (#8)
* Added to_dot_file to python interface * Added circuit.nb_root_nodes() * Re-enabling remove_unused_nodes * Added tests for removing unused nodes. * Added comment
1 parent 8851cd1 commit 6cc9e99

File tree

4 files changed

+148
-11
lines changed

4 files changed

+148
-11
lines changed

src/circuit.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Node* Circuit::add_node_level(Node* node) {
6767
}
6868

6969
Node* Circuit::add_node_level_compressed(Node* node) {
70-
return add_node_level(node);
70+
// return add_node_level(node); // To disable compression.
7171
if (node->type != NodeType::And && node->type != NodeType::Or)
7272
return add_node_level(node);
7373

@@ -206,6 +206,8 @@ size_t Circuit::max_layer_width() const {
206206
}
207207

208208
void Circuit::remove_unused_nodes() {
209+
// Should be run before adding a final root layer;
210+
// because it might change ix's.
209211
std::vector<std::vector<bool>> used;
210212
used.reserve(nb_layers());
211213
for (const auto& layer : layers)
@@ -238,7 +240,8 @@ void Circuit::remove_unused_nodes() {
238240
}
239241
}
240242

241-
// Clean-up: last layers can be empty (but intermediate ones should not)
243+
// Clean-up: last layers can be empty, pop those.
244+
// Intermediate layers can not be empty because we use dummy nodes.
242245
for (std::size_t i = nb_layers()-1; i > 0; --i) {
243246
if (layers[i].empty()) {
244247
layers.pop_back();
@@ -257,14 +260,10 @@ void Circuit::remove_unused_nodes() {
257260
for (auto &node : layers[i])
258261
node->ix = index++;
259262
}
260-
// Clean-up: last layer has fixed ix order
261-
for(size_t i = 0; i < roots.size(); ++i)
262-
roots[i]->ix = i;
263+
263264

264265
#ifndef NDEBUG
265266
// print_circuit();
266-
// assert, last layer should only contain root nodes.
267-
assert(roots.size() == layers[nb_layers()-1].size());
268267

269268
if (layers.size() > 2) {
270269
// check for each layer, for each node, whether the idx
@@ -363,6 +362,15 @@ void to_dot_file(Circuit& circuit, const std::string& filename) {
363362
file << " " << node->hash << " [label=\"" << node->get_label() << "\"]" << std::endl;
364363
}
365364
}
365+
// Group nodes per layer
366+
// using { rank=same; 1; 2; } to group node 1 and 2
367+
for (const auto &layer: circuit.layers) {
368+
file << " { rank=same; ";
369+
for (const auto *node : layer) {
370+
file << node->hash << "; ";
371+
}
372+
file << "}" << std::endl;
373+
}
366374
file << "}" << std::endl;
367375
}
368376

@@ -402,12 +410,13 @@ void Circuit::add_root_layer() {
402410

403411

404412
/**
 * Deleter callback: releases an array of `long int` that is handed over
 * as an untyped pointer.
 *
 * NOTE(review): `delete[]` is only valid for memory obtained from
 * `new long int[...]` (or a null pointer) -- confirm at the call sites.
 */
void cleanup(void* data) noexcept {
    auto* buffer = static_cast<long int*>(data);
    delete[] buffer;
}
407415

408416

409417
std::pair<Arrays, Arrays> Circuit::tensorize() {
410-
add_root_layer();
418+
remove_unused_nodes();
419+
add_root_layer();
411420
//print_circuit(); // Helpful for debugging small circuits
412421

413422
// per layer, a vector of size the number of children (but children can count twice
@@ -469,10 +478,14 @@ nb::class_<Circuit>(m, "Circuit")
469478
.def("add_D4_from_file", &Circuit::add_D4_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>())
470479
.def("get_indices", &Circuit::get_indices)
471480
.def("nb_nodes", &Circuit::nb_nodes, "number of nodes in the circuit")
481+
.def("nb_root_nodes", &Circuit::nb_root_nodes, "number of root nodes in the circuit")
472482
.def("true_node", &Circuit::true_node, "adds a true node to the circuit, and returns a pointer")
473483
.def("false_node", &Circuit::false_node, "adds a false node to the circuit, and returns a pointer")
474484
.def("literal_node", &Circuit::literal_node, "adds a literal node to the circuit ,and returns a pointer")
475485
.def("or_node", &Circuit::or_node, "children"_a, "adds an or node to the circuit, and returns a pointer")
476486
.def("and_node", &Circuit::and_node, "children"_a, "adds an and node to the circuit, and returns a pointer")
477-
.def("set_root", &Circuit::set_root, "root"_a, "marks a node pointer as root");
487+
.def("set_root", &Circuit::set_root, "root"_a, "marks a node pointer as root")
488+
.def("remove_unused_nodes", &Circuit::remove_unused_nodes, "Removes unused non-root nodes from the circuit.\nCareful! This invalidates any NodePtr refering to an unused node (i.e., a node not conneected to a root node).");
489+
490+
m.def("to_dot_file", &to_dot_file, "circuit"_a, "filename"_a, "Write the given circuit as dot format to a file");
478491
}

src/circuit.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ class Circuit {
195195
return count;
196196
}
197197

198+
std::size_t nb_root_nodes() const {
199+
return roots.size();
200+
}
201+
198202
/**
199203
* For debugging purposes;
200204
* prints every node of each layer

src/klay/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from time import perf_counter
33
import random
44
from array import array
5+
# noinspection PyUnresolvedReferences
6+
from .nanobind_ext import to_dot_file
57

68
import torch
79
try:
@@ -248,4 +250,12 @@ def jax_weights(nb_vars, semiring = "log"):
248250
weights, neg_weights = python_weights(nb_vars, semiring)
249251
weights = jax.numpy.array(weights)
250252
neg_weights = jax.numpy.array(neg_weights)
251-
return weights, neg_weights
253+
return weights, neg_weights
254+
255+
def circuit_to_dot(circuit, filename):
    """
    Export a circuit to a file in dot format.

    Thin wrapper that forwards both arguments, unchanged, to the
    ``to_dot_file`` function of the compiled extension.

    :param circuit: The circuit to export.
    :param filename: The filepath the dot output is written to.
    """
    to_dot_file(circuit, filename)

tests/test_compression.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import klay
2+
3+
4+
def test_propagate_simple_true():
    """Combining a True node with another node must not add nodes to the circuit."""
    c = klay.Circuit()
    true_node = c.true_node()
    false_node = c.false_node()  # not combined below; it only contributes to the node count
    lit1, lit2 = c.literal_node(1), c.literal_node(2)
    assert c.nb_nodes() == 4

    # True combined with an input node.
    for build, children in (
        (c.and_node, [true_node, lit1]),  # expected to simplify to l1
        (c.and_node, [lit1, true_node]),  # expected to simplify to l1
        (c.or_node, [lit1, true_node]),   # expected to simplify to true
        (c.or_node, [true_node, lit1]),   # expected to simplify to true
    ):
        build(children)
    assert c.nb_nodes() == 4

    # True combined with an intermediate node.
    lit1_and_lit2 = c.and_node([lit1, lit2])
    assert c.nb_nodes() == 5
    for build, children in (
        (c.and_node, [true_node, lit1_and_lit2]),  # expected to simplify to l1 & l2
        (c.and_node, [lit1_and_lit2, true_node]),  # expected to simplify to l1 & l2
        (c.or_node, [lit1_and_lit2, true_node]),   # expected to simplify to true
        (c.or_node, [true_node, lit1_and_lit2]),   # expected to simplify to true
    ):
        build(children)
    assert c.nb_nodes() == 5
26+
27+
28+
def test_propagate_simple_false():
    """Combining a False node with another node must not add nodes to the circuit."""
    c = klay.Circuit()
    true_node = c.true_node()  # not combined below; it only contributes to the node count
    false_node = c.false_node()
    lit1, lit2 = c.literal_node(1), c.literal_node(2)
    assert c.nb_nodes() == 4

    # False combined with an input node.
    for build, children in (
        (c.and_node, [false_node, lit1]),  # expected to simplify to false
        (c.and_node, [lit1, false_node]),  # expected to simplify to false
        (c.or_node, [lit1, false_node]),   # expected to simplify to l1
        (c.or_node, [false_node, lit1]),   # expected to simplify to l1
    ):
        build(children)
    assert c.nb_nodes() == 4

    # False combined with an intermediate node.
    lit1_and_lit2 = c.and_node([lit1, lit2])
    assert c.nb_nodes() == 5
    for build, children in (
        (c.and_node, [false_node, lit1_and_lit2]),  # expected to simplify to false
        (c.and_node, [lit1_and_lit2, false_node]),  # expected to simplify to false
        (c.or_node, [lit1_and_lit2, false_node]),   # expected to simplify to l1 & l2
        (c.or_node, [false_node, lit1_and_lit2]),   # expected to simplify to l1 & l2
    ):
        build(children)
    assert c.nb_nodes() == 5
50+
51+
52+
def test_propagate_simple_ternary():
    """Constant propagation on nodes with three children."""
    c = klay.Circuit()
    true_node = c.true_node()
    false_node = c.false_node()
    lit1, lit2 = c.literal_node(1), c.literal_node(2)
    assert c.nb_nodes() == 4

    # Only the very first call adds a node (l1 & l2); every later call in
    # this table must leave the node count at 5.
    five_node_cases = (
        (c.and_node, [true_node, lit1, lit2]),   # l1 & l2 (new node)
        (c.and_node, [lit2, true_node, lit1]),   # l1 & l2
        (c.or_node, [lit1, true_node, lit2]),    # true
        (c.or_node, [lit2, lit1, true_node]),    # true
        (c.and_node, [false_node, lit1, lit2]),  # false
        (c.and_node, [lit2, false_node, lit1]),  # false
    )
    for build, children in five_node_cases:
        build(children)
        assert c.nb_nodes() == 5

    # An OR with a False child reduces to l1 | l2, which is a new node.
    c.or_node([lit1, false_node, lit2])
    assert c.nb_nodes() == 8, "Expected 8 nodes instead of 6, because l1 and l2 require dummy nodes for the OR-node."
    # The same OR with its children in another order must not add nodes again.
    c.or_node([lit2, lit1, false_node])
    assert c.nb_nodes() == 8
79+
80+
81+
def test_removing_useless_nodes1():
    """An unused node in a layer above the root is removed by remove_unused_nodes()."""
    c = klay.Circuit()
    lit1, lit2, lit3 = (c.literal_node(var) for var in (1, 2, 3))
    assert c.nb_nodes() == 3

    root = c.and_node([lit1, lit2])
    assert c.nb_nodes() == 4

    unused_or = c.or_node([root, lit3])
    assert c.nb_nodes() == 6  # unused_or + 1 dummy node

    # Marking the root does not remove anything by itself, even though
    # unused_or sits in a layer above the root and is not used by it.
    c.set_root(root)
    assert c.nb_nodes() == 6

    # Expected to drop unused_or together with its dummy node.
    c.remove_unused_nodes()
    assert c.nb_nodes() == 4, f"Expected 4 nodes instead of {c.nb_nodes()}"
94+
95+
96+
def test_removing_useless_nodes2():
    """Unused nodes next to the root's sub-circuit are removed by remove_unused_nodes()."""
    c = klay.Circuit()
    l1, l2, l3 = c.literal_node(1), c.literal_node(2), c.literal_node(3)
    assert c.nb_nodes() == 3
    and1 = c.and_node([l1, l2])
    assert c.nb_nodes() == 4
    or1 = c.or_node([and1, l3])
    assert c.nb_nodes() == 6  # or1 + 1 dummy node
    and2 = c.and_node([l1, l3])  # useless: not reachable from the root set below
    assert c.nb_nodes() == 7
    or2 = c.or_node([l1, l2])  # useless: not reachable from the root set below
    assert c.nb_nodes() == 10  # or2 + 2 dummy nodes
    c.set_root(or1)
    c.remove_unused_nodes()  # should remove `and2`, `or2`, and 2 dummy nodes
    # The failure message now states the count the assertion actually checks (6, not 5).
    assert c.nb_nodes() == 6, f"Expected 6 nodes instead of {c.nb_nodes()}"

0 commit comments

Comments
 (0)