Skip to content

Commit ebb01c0

Browse files
committed
Fix BddTrait logic issue, using wrong conditions
We were using the wrong condition ordering in BddTrait after compiling a Bdd from the CFG, leading to a totally broken BDD. Also adds some tests, fixes, and generalizes BddTrait transforms
1 parent 1487287 commit ebb01c0

13 files changed

Lines changed: 635 additions & 247 deletions

File tree

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import java.io.UncheckedIOException;
1111
import java.io.Writer;
1212
import java.nio.charset.StandardCharsets;
13-
import java.util.Arrays;
1413
import java.util.Objects;
1514
import java.util.function.Consumer;
1615
import software.amazon.smithy.rulesengine.logic.ConditionEvaluator;
@@ -44,6 +43,7 @@ public final class Bdd {
4443
private final int rootRef;
4544
private final int conditionCount;
4645
private final int resultCount;
46+
private final int nodeCount;
4747

4848
/**
4949
* Creates a BDD by streaming nodes directly into the structure.
@@ -58,6 +58,7 @@ public Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, Cons
5858
this.rootRef = rootRef;
5959
this.conditionCount = conditionCount;
6060
this.resultCount = resultCount;
61+
this.nodeCount = nodeCount;
6162

6263
if (rootRef < 0 && rootRef != -1) {
6364
throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef);
@@ -96,20 +97,21 @@ public void accept(int var, int high, int low) {
9697
}
9798
}
9899

99-
Bdd(int[] variables, int[] highs, int[] lows, int rootRef, int conditionCount, int resultCount) {
100+
Bdd(int[] variables, int[] highs, int[] lows, int nodeCount, int rootRef, int conditionCount, int resultCount) {
100101
this.variables = Objects.requireNonNull(variables, "variables is null");
101102
this.highs = Objects.requireNonNull(highs, "highs is null");
102103
this.lows = Objects.requireNonNull(lows, "lows is null");
103104
this.rootRef = rootRef;
104105
this.conditionCount = conditionCount;
105106
this.resultCount = resultCount;
107+
this.nodeCount = nodeCount;
106108

107109
if (rootRef < 0 && rootRef != -1) {
108110
throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef);
109-
}
110-
111-
if (variables.length != highs.length || variables.length != lows.length) {
111+
} else if (variables.length != highs.length || variables.length != lows.length) {
112112
throw new IllegalArgumentException("Array lengths must match");
113+
} else if (nodeCount > variables.length) {
114+
throw new IllegalArgumentException("Node count exceeds array capacity");
113115
}
114116
}
115117

@@ -137,7 +139,7 @@ public int getResultCount() {
137139
* @return the node count
138140
*/
139141
public int getNodeCount() {
140-
return variables.length;
142+
return nodeCount;
141143
}
142144

143145
/**
@@ -156,16 +158,24 @@ public int getRootRef() {
156158
* @return the variable index
157159
*/
158160
public int getVariable(int nodeIndex) {
161+
validateRange(nodeIndex);
159162
return variables[nodeIndex];
160163
}
161164

165+
private void validateRange(int index) {
166+
if (index < 0 || index >= nodeCount) {
167+
throw new IndexOutOfBoundsException("Node index out of bounds: " + index + " (size: " + nodeCount + ")");
168+
}
169+
}
170+
162171
/**
163172
* Gets the high (true) reference for a node.
164173
*
165174
* @param nodeIndex the node index (0-based)
166175
* @return the high reference
167176
*/
168177
public int getHigh(int nodeIndex) {
178+
validateRange(nodeIndex);
169179
return highs[nodeIndex];
170180
}
171181

@@ -176,6 +186,7 @@ public int getHigh(int nodeIndex) {
176186
* @return the low reference
177187
*/
178188
public int getLow(int nodeIndex) {
189+
validateRange(nodeIndex);
179190
return lows[nodeIndex];
180191
}
181192

@@ -185,7 +196,7 @@ public int getLow(int nodeIndex) {
185196
* @param consumer the consumer to receive the integers
186197
*/
187198
public void getNodes(BddNodeConsumer consumer) {
188-
for (int i = 0; i < variables.length; i++) {
199+
for (int i = 0; i < nodeCount; i++) {
189200
consumer.accept(variables[i], highs[i], lows[i]);
190201
}
191202
}
@@ -264,21 +275,31 @@ public boolean equals(Object obj) {
264275
} else if (!(obj instanceof Bdd)) {
265276
return false;
266277
}
278+
267279
Bdd other = (Bdd) obj;
268-
return rootRef == other.rootRef
269-
&& conditionCount == other.conditionCount
270-
&& resultCount == other.resultCount
271-
&& Arrays.equals(variables, other.variables)
272-
&& Arrays.equals(highs, other.highs)
273-
&& Arrays.equals(lows, other.lows);
280+
if (rootRef != other.rootRef
281+
|| conditionCount != other.conditionCount
282+
|| resultCount != other.resultCount
283+
|| nodeCount != other.nodeCount) {
284+
return false;
285+
}
286+
287+
// Now check the views of arrays of each.
288+
for (int i = 0; i < nodeCount; i++) {
289+
if (variables[i] != other.variables[i] || highs[i] != other.highs[i] || lows[i] != other.lows[i]) {
290+
return false;
291+
}
292+
}
293+
294+
return true;
274295
}
275296

276297
@Override
277298
public int hashCode() {
278-
int hash = 31 * rootRef + variables.length;
299+
int hash = 31 * rootRef + nodeCount;
279300
// Sample up to 16 nodes distributed across the BDD
280-
int step = Math.max(1, variables.length / 16);
281-
for (int i = 0; i < variables.length; i += step) {
301+
int step = Math.max(1, nodeCount / 16);
302+
for (int i = 0; i < nodeCount; i += step) {
282303
hash = 31 * hash + variables[i];
283304
hash = 31 * hash + highs[i];
284305
hash = 31 * hash + lows[i];

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ public BddBuilder() {
6363
lows[0] = FALSE_REF;
6464
}
6565

66+
int getNodeCount() {
67+
return nodeCount;
68+
}
69+
6670
/**
6771
* Sets the number of conditions. Must be called before creating result nodes.
6872
*
@@ -170,8 +174,8 @@ private int insertNode(int var, int high, int low, boolean flip) {
170174

171175
private void ensureCapacity() {
172176
if (nodeCount >= variables.length) {
173-
// Grow by 50%
174-
int newCapacity = variables.length + (variables.length >> 1);
177+
// Double the current capacity
178+
int newCapacity = variables.length * 2;
175179
variables = Arrays.copyOf(variables, newCapacity);
176180
highs = Arrays.copyOf(highs, newCapacity);
177181
lows = Arrays.copyOf(lows, newCapacity);
@@ -592,26 +596,11 @@ public BddBuilder reset() {
592596
return this;
593597
}
594598

595-
/**
596-
* Get the nodes as a flat array.
597-
*
598-
* @return array of nodes, trimmed to actual size.
599-
*/
600-
public int[] getNodesArray() {
601-
// Convert back to flat array for compatibility
602-
int[] result = new int[nodeCount * 3];
603-
for (int i = 0; i < nodeCount; i++) {
604-
int baseIdx = i * 3;
605-
result[baseIdx] = variables[i];
606-
result[baseIdx + 1] = highs[i];
607-
result[baseIdx + 2] = lows[i];
608-
}
609-
return result;
610-
}
611-
612599
/**
613600
* Builds a BDD from the current state of the builder.
614601
*
602+
* <p>The builder must be reset() before reuse after calling this method.
603+
*
615604
* @return a new BDD instance
616605
* @throws IllegalStateException if condition count has not been set
617606
*/
@@ -620,11 +609,7 @@ Bdd build(int rootRef, int resultCount) {
620609
throw new IllegalStateException("Condition count must be set before building BDD");
621610
}
622611

623-
// Create trimmed copies of the arrays with only the used portion
624-
int[] trimmedVariables = Arrays.copyOf(variables, nodeCount);
625-
int[] trimmedHighs = Arrays.copyOf(highs, nodeCount);
626-
int[] trimmedLows = Arrays.copyOf(lows, nodeCount);
627-
return new Bdd(trimmedVariables, trimmedHighs, trimmedLows, rootRef, conditionCount, resultCount);
612+
return new Bdd(variables, highs, lows, nodeCount, rootRef, conditionCount, resultCount);
628613
}
629614

630615
private void validateBooleanOperands(int f, int g, String operation) {

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ Bdd compile() {
6262
noMatchIndex = getOrCreateResultIndex(NoMatchRule.INSTANCE);
6363
int rootRef = convertCfgToBdd(cfg.getRoot());
6464
rootRef = bddBuilder.reduce(rootRef);
65-
6665
Bdd bdd = bddBuilder.build(rootRef, indexedResults.size());
66+
6767
long elapsed = System.currentTimeMillis() - start;
6868
LOGGER.fine(String.format(
6969
"BDD compilation complete: %d conditions, %d results, %d BDD nodes in %dms",
@@ -75,6 +75,14 @@ Bdd compile() {
7575
return bdd;
7676
}
7777

78+
List<Rule> getIndexedResults() {
79+
return indexedResults;
80+
}
81+
82+
List<Condition> getOrderedConditions() {
83+
return orderedConditions;
84+
}
85+
7886
private int convertCfgToBdd(CfgNode cfgNode) {
7987
Integer cached = nodeCache.get(cfgNode);
8088
if (cached != null) {

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
import java.io.IOException;
1212
import java.io.UncheckedIOException;
1313
import java.util.ArrayList;
14-
import java.util.Arrays;
1514
import java.util.Base64;
16-
import java.util.LinkedHashSet;
1715
import java.util.List;
1816
import java.util.Set;
17+
import java.util.function.Function;
1918
import software.amazon.smithy.model.node.Node;
2019
import software.amazon.smithy.model.node.ObjectNode;
2120
import software.amazon.smithy.model.shapes.ShapeId;
@@ -27,9 +26,6 @@
2726
import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule;
2827
import software.amazon.smithy.rulesengine.language.syntax.rule.Rule;
2928
import software.amazon.smithy.rulesengine.logic.cfg.Cfg;
30-
import software.amazon.smithy.rulesengine.logic.cfg.CfgNode;
31-
import software.amazon.smithy.rulesengine.logic.cfg.ConditionData;
32-
import software.amazon.smithy.rulesengine.logic.cfg.ResultNode;
3329
import software.amazon.smithy.utils.SetUtils;
3430
import software.amazon.smithy.utils.SmithyBuilder;
3531
import software.amazon.smithy.utils.ToSmithyBuilder;
@@ -68,43 +64,19 @@ private BddTrait(Builder builder) {
6864
* @return the BddTrait containing the compiled BDD and all context
6965
*/
7066
public static BddTrait from(Cfg cfg) {
71-
ConditionData conditionData = cfg.getConditionData();
72-
List<Condition> conditions = Arrays.asList(conditionData.getConditions());
73-
74-
// Compile the BDD
7567
BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder());
7668
Bdd bdd = compiler.compile();
7769

78-
List<Rule> results = extractResultsFromCfg(cfg, bdd);
79-
Parameters parameters = cfg.getRuleSet().getParameters();
80-
return builder().parameters(parameters).conditions(conditions).results(results).bdd(bdd).build();
81-
}
82-
83-
private static List<Rule> extractResultsFromCfg(Cfg cfg, Bdd bdd) {
84-
// The BddCompiler always puts NoMatchRule at index 0
85-
List<Rule> results = new ArrayList<>();
86-
results.add(NoMatchRule.INSTANCE);
87-
88-
Set<Rule> uniqueResults = new LinkedHashSet<>();
89-
for (CfgNode node : cfg) {
90-
if (node instanceof ResultNode) {
91-
Rule result = ((ResultNode) node).getResult();
92-
if (result != null && !(result instanceof NoMatchRule)) {
93-
uniqueResults.add(result.withoutConditions());
94-
}
95-
}
96-
}
97-
98-
results.addAll(uniqueResults);
99-
100-
if (results.size() != bdd.getResultCount()) {
101-
throw new IllegalStateException(String.format(
102-
"Result count mismatch: found %d results in CFG but BDD expects %d",
103-
results.size(),
104-
bdd.getResultCount()));
70+
if (compiler.getOrderedConditions().size() != bdd.getConditionCount()) {
71+
throw new IllegalStateException("Mismatch between BDD var count and orderedConditions size");
10572
}
10673

107-
return results;
74+
return builder()
75+
.parameters(cfg.getRuleSet().getParameters())
76+
.conditions(compiler.getOrderedConditions())
77+
.results(compiler.getIndexedResults())
78+
.bdd(bdd)
79+
.build();
10880
}
10981

11082
/**
@@ -143,6 +115,16 @@ public Bdd getBdd() {
143115
return bdd;
144116
}
145117

118+
/**
119+
* Transform this BDD using the given function and return the updated BddTrait.
120+
*
121+
* @param transformer Transformer used to modify the trait.
122+
* @return the updated trait.
123+
*/
124+
public BddTrait transform(Function<BddTrait, BddTrait> transformer) {
125+
return transformer.apply(this);
126+
}
127+
146128
@Override
147129
protected Node createNode() {
148130
ObjectNode.Builder builder = ObjectNode.builder();

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,24 @@
1313
* <p>This transformation reverses the node array (except the terminal at index 0)
1414
* and updates all references throughout the BDD to maintain correctness.
1515
*/
16-
public final class NodeReversal implements Function<Bdd, Bdd> {
16+
public final class NodeReversal implements Function<BddTrait, BddTrait> {
1717

1818
private static final Logger LOGGER = Logger.getLogger(NodeReversal.class.getName());
1919

2020
@Override
21-
public Bdd apply(Bdd bdd) {
21+
public BddTrait apply(BddTrait trait) {
22+
Bdd reversedBdd = reverse(trait.getBdd());
23+
// Only rebuild the trait if the BDD actually changed
24+
return reversedBdd == trait.getBdd() ? trait : trait.toBuilder().bdd(reversedBdd).build();
25+
}
26+
27+
/**
28+
* Reverses the node ordering in a BDD.
29+
*
30+
* @param bdd the BDD to reverse
31+
* @return the reversed BDD, or the original if too small to reverse
32+
*/
33+
public static Bdd reverse(Bdd bdd) {
2234
LOGGER.info("Starting BDD node reversal optimization");
2335
int nodeCount = bdd.getNodeCount();
2436

@@ -62,7 +74,7 @@ public Bdd apply(Bdd bdd) {
6274
* @param oldToNew the index mapping array
6375
* @return the remapped reference
6476
*/
65-
private int remapReference(int ref, int[] oldToNew) {
77+
private static int remapReference(int ref, int[] oldToNew) {
6678
// Return result references as-is.
6779
if (ref == 0) {
6880
return 0;

0 commit comments

Comments
 (0)