Skip to content

Commit dfe7e6c

Browse files
committed
remove the child merge stuff
1 parent 1f75563 commit dfe7e6c

File tree

4 files changed

+2
-79
lines changed

4 files changed

+2
-79
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ __pycache__
1212
cmake-build-debug
1313
.venv
1414
tmp
15-
data
15+
data
16+
_build

src/circuit.cpp

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -108,30 +108,6 @@ Node* Circuit::add_node_level_compressed(Node* node) {
108108
}
109109

110110

111-
Node* Circuit::add_node_merge(Node *node) {
112-
Node* new_node; // The hash of the first node might be spoiled
113-
if (node->type == NodeType::Or) {
114-
new_node = Node::createOrNode();
115-
} else if (node->type == NodeType::And) {
116-
new_node = Node::createAndNode();
117-
} else {
118-
return add_node_level_compressed(node);
119-
}
120-
121-
emhash8::HashSet<Node*, NodeHash, NodeEqual> children_set = {};
122-
for (auto child: node->children) {
123-
while (child->layer < node->layer - 1)
124-
child = add_node(child->dummy_parent());
125-
children_set.insert(child);
126-
}
127-
delete node;
128-
129-
for (auto child: children_set)
130-
new_node->add_child(child);
131-
132-
return add_node_level_compressed(new_node);
133-
}
134-
135111
/**
136112
* Auxiliary method for Circuit::add_sdd_from_file
137113
*/
@@ -492,37 +468,6 @@ std::pair<Arrays, Arrays> Circuit::get_indices() {
492468
}
493469

494470

495-
NodePtr Circuit::disjoin(std::vector<NodePtr> nodes) {
496-
Node* or_node = Node::createOrNode();
497-
for (NodePtr node_ptr: nodes) {
498-
Node* node = node_ptr.get();
499-
if (node->type == NodeType::Or) {
500-
for (Node* ch: node->children)
501-
or_node->add_child(ch);
502-
} else {
503-
or_node->add_child(node);
504-
}
505-
}
506-
return NodePtr(add_node_merge(or_node));
507-
}
508-
509-
510-
NodePtr Circuit::conjoin(std::vector<NodePtr> nodes) {
511-
Node* and_node = Node::createAndNode();
512-
for (NodePtr node_ptr: nodes) {
513-
Node* node = node_ptr.get();
514-
if (node->type == NodeType::And) {
515-
for (Node* ch: node->children)
516-
and_node->add_child(ch);
517-
} else {
518-
and_node->add_child(node);
519-
}
520-
}
521-
return NodePtr(add_node_merge(and_node));
522-
}
523-
524-
525-
526471
NB_MODULE(nanobind_ext, m) {
527472
m.doc() = "Layerize arithmetic circuits";
528473

@@ -537,8 +482,6 @@ nb::class_<Circuit>(m, "Circuit", "Circuits are the main class added by KLay, an
537482
.def("add_sdd_from_file", &Circuit::add_sdd_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>(), "Add an SDD circuit from file.\n\n:param filename:\n\tPath to the :code:`.sdd` file on disk.\n:param true_lits:\n\tList of literals that are always true and should get propagated away.\n:param false_lits:\n\tList of literals that are always false and should get propagated away.")
538483
.def("add_d4_from_file", &Circuit::add_d4_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>(), "Add an NNF circuit in the D4 format from file.\n\n:param filename:\n\tPath to the :code:`.nnf` file on disk.\n:param true_lits:\n\tList of literals that are always true and should get propagated away.\n:param false_lits:\n\tList of literals that are always false and should get propagated away.")
539484
.def("_get_indices", &Circuit::get_indices)
540-
.def("disjoin", &Circuit::disjoin)
541-
.def("conjoin", &Circuit::conjoin)
542485
.def("nb_nodes", &Circuit::nb_nodes, "Number of nodes in the circuit.")
543486
.def("nb_root_nodes", &Circuit::nb_root_nodes, "Number of root nodes in the circuit.")
544487
.def("true_node", &Circuit::true_node, "Adds a true node to the circuit, and returns a pointer to this node.")

src/circuit.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,6 @@ class Circuit {
132132
*/
133133
Node* add_node_level_compressed(Node* node);
134134

135-
/**
136-
* De-duplicate the children of the node and add it to the circuit.
137-
*/
138-
Node* add_node_merge(Node* node);
139-
140135
/**
141136
* Get the corresponding node in the circuit.
142137
* This may be a different node instance with the same hash and
@@ -232,7 +227,4 @@ class Circuit {
232227
}
233228
return NodePtr(add_node_level_compressed(node));
234229
}
235-
236-
NodePtr disjoin(std::vector<NodePtr> nodes);
237-
NodePtr conjoin(std::vector<NodePtr> nodes);
238230
};

tests/test_manual.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,6 @@ def test_or_node():
1919
assert m(weights) == 0.4 + (1 - 0.8)
2020

2121

22-
def test_disjoin_conjoin():
23-
c = klay.Circuit()
24-
l1, l2, l3 = c.literal_node(1), c.literal_node(-2), c.literal_node(3)
25-
or_node1 = c.disjoin([l1, l2, l2])
26-
or_node2 = c.disjoin([l3, or_node1, l3, or_node1])
27-
c.set_root(or_node2)
28-
29-
m = c.to_torch_module(semiring='real')
30-
weights = torch.tensor([0.4, 0.8, 0.5])
31-
expected_result = torch.tensor(0.4 + (1 - 0.8) + 0.5)
32-
assert torch.allclose(m(weights), expected_result)
33-
34-
3522
def test_probabilistic():
3623
c = klay.Circuit()
3724
l1, l2, l3 = c.literal_node(1), c.literal_node(-2), c.literal_node(3)

0 commit comments

Comments
 (0)