Skip to content

Commit 8fbb617

Browse files
committed
add node merge
1 parent 217b6b1 commit 8fbb617

File tree

2 files changed

+72
-3
lines changed

2 files changed

+72
-3
lines changed

src/circuit.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,31 @@ Node* Circuit::add_node_level_compressed(Node* node) {
107107
return add_node_level(node);
108108
}
109109

110+
111+
Node* Circuit::add_node_merge(Node *node) {
112+
Node* new_node; // The hash of the first node might be spoiled
113+
if (node->type == NodeType::Or) {
114+
new_node = Node::createOrNode();
115+
} else if (node->type == NodeType::And) {
116+
new_node = Node::createAndNode();
117+
} else {
118+
return add_node_level_compressed(node);
119+
}
120+
121+
emhash8::HashSet<Node*, NodeHash, NodeEqual> children_set = {};
122+
for (auto child: node->children) {
123+
while (child->layer < node->layer - 1)
124+
child = add_node(child->dummy_parent());
125+
children_set.insert(child);
126+
}
127+
delete node;
128+
129+
for (auto child: children_set)
130+
new_node->add_child(child);
131+
132+
return add_node_level_compressed(new_node);
133+
}
134+
110135
/**
111136
* Auxiliary method for Circuit::add_sdd_from_file
112137
*/
@@ -442,20 +467,52 @@ std::pair<Arrays, Arrays> Circuit::tensorize() {
442467
}
443468

444469

470+
NodePtr Circuit::disjoin(std::vector<NodePtr> nodes) {
471+
Node* or_node = Node::createOrNode();
472+
for (NodePtr node_ptr: nodes) {
473+
Node* node = node_ptr.get();
474+
if (node->type == NodeType::Or) {
475+
for (Node* ch: node->children)
476+
or_node->add_child(ch);
477+
} else {
478+
or_node->add_child(node);
479+
}
480+
}
481+
return NodePtr(add_node_merge(or_node));
482+
}
483+
484+
485+
NodePtr Circuit::conjoin(std::vector<NodePtr> nodes) {
486+
Node* and_node = Node::createAndNode();
487+
for (NodePtr node_ptr: nodes) {
488+
Node* node = node_ptr.get();
489+
if (node->type == NodeType::And) {
490+
for (Node* ch: node->children)
491+
and_node->add_child(ch);
492+
} else {
493+
and_node->add_child(node);
494+
}
495+
}
496+
return NodePtr(add_node_merge(and_node));
497+
}
498+
499+
445500

446501
NB_MODULE(nanobind_ext, m) {
447502
m.doc() = "Layerize arithmetic circuits";
448503

449504
nb::class_<NodePtr>(m, "NodePtr")
450505
.def("__repr__", &NodePtr::to_string)
451506
.def(nb::self == nb::self)
452-
.def("__hash__", [](const NodePtr &a) {return (std::size_t) a.get();});
507+
.def("__hash__", &NodePtr::as_int);
453508

454509
nb::class_<Circuit>(m, "Circuit", "Circuits are the main class added by KLay, and require no arguments to construct.\n\n:code:`circuit = klay.Circuit()` ")
455510
.def(nb::init<>())
456511
.def("add_sdd_from_file", &Circuit::add_sdd_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>(), "Add an SDD circuit from file.\n\n:param filename:\n\tPath to the :code:`.sdd` file on disk.\n:param true_lits:\n\tList of literals that are always true and should get propagated away.\n:param false_lits:\n\tList of literals that are always false and should get propagated away.")
457512
.def("add_d4_from_file", &Circuit::add_d4_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>(), "Add an NNF circuit in the D4 format from file.\n\n:param filename:\n\tPath to the :code:`.nnf` file on disk.\n:param true_lits:\n\tList of literals that are always true and should get propagated away.\n:param false_lits:\n\tList of literals that are always false and should get propagated away.")
458513
.def("_get_indices", &Circuit::get_indices)
514+
.def("disjoin", &Circuit::disjoin)
515+
.def("conjoin", &Circuit::conjoin)
459516
.def("nb_nodes", &Circuit::nb_nodes, "Number of nodes in the circuit.")
460517
.def("nb_root_nodes", &Circuit::nb_root_nodes, "Number of root nodes in the circuit.")
461518
.def("true_node", &Circuit::true_node, "Adds a true node to the circuit, and returns a pointer to this node.")

src/circuit.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <sstream>
1313
#include <vector>
1414
#include <list>
15+
#include <cstdint>
1516

1617
#include "node.h"
1718
#include "hash_set8.hpp"
@@ -31,16 +32,19 @@ class NodePtr {
3132
}
3233

3334
std::string to_string() const {
34-
const void * address = static_cast<const void*>(ptr);
3535
std::stringstream ss;
36-
ss << "NodePtr(" << address << ")";
36+
ss << "NodePtr(" << this->as_int() << ")";
3737
return ss.str();
3838
}
3939

4040
bool operator==(NodePtr other) const {
4141
return this->ptr == other.ptr;
4242
}
4343

44+
std::uintptr_t as_int() const {
45+
return reinterpret_cast<std::uintptr_t>(ptr);
46+
}
47+
4448
private:
4549
Node* ptr;
4650
};
@@ -128,6 +132,11 @@ class Circuit {
128132
*/
129133
Node* add_node_level_compressed(Node* node);
130134

135+
/**
136+
* De-duplicate the children of the node and add it to the circuit.
137+
*/
138+
Node* add_node_merge(Node* node);
139+
131140
/**
132141
* Get the corresponding node in the circuit.
133142
* This may be a different node instance with the same hash and
@@ -225,4 +234,7 @@ class Circuit {
225234
}
226235
return NodePtr(add_node_level_compressed(node));
227236
}
237+
238+
NodePtr disjoin(std::vector<NodePtr> nodes);
239+
NodePtr conjoin(std::vector<NodePtr> nodes);
228240
};

0 commit comments

Comments
 (0)