1010import java .io .UncheckedIOException ;
1111import java .io .Writer ;
1212import java .nio .charset .StandardCharsets ;
13- import java .util .Arrays ;
14- import java .util .Objects ;
1513import java .util .function .Consumer ;
1614import software .amazon .smithy .rulesengine .logic .ConditionEvaluator ;
1715
@@ -44,6 +42,7 @@ public final class Bdd {
4442 private final int rootRef ;
4543 private final int conditionCount ;
4644 private final int resultCount ;
45+ private final int nodeCount ;
4746
4847 /**
4948 * Creates a BDD by streaming nodes directly into the structure.
@@ -55,23 +54,68 @@ public final class Bdd {
5554 * @param nodeHandler a handler that will provide nodes via a consumer
5655 */
5756 public Bdd (int rootRef , int conditionCount , int resultCount , int nodeCount , Consumer <BddNodeConsumer > nodeHandler ) {
57+ validateCounts (conditionCount , resultCount , nodeCount );
58+ validateRootReference (rootRef , nodeCount );
59+
5860 this .rootRef = rootRef ;
5961 this .conditionCount = conditionCount ;
6062 this .resultCount = resultCount ;
61-
62- if (rootRef < 0 && rootRef != -1 ) {
63- throw new IllegalArgumentException ("Root reference cannot be complemented: " + rootRef );
64- }
63+ this .nodeCount = nodeCount ;
6564
6665 InputNodeConsumer consumer = new InputNodeConsumer (nodeCount );
6766 nodeHandler .accept (consumer );
68-
6967 this .variables = consumer .variables ;
7068 this .highs = consumer .highs ;
7169 this .lows = consumer .lows ;
7270
7371 if (consumer .index != nodeCount ) {
74- throw new IllegalStateException ("Expected " + nodeCount + " node, but got " + consumer .index );
72+ throw new IllegalStateException ("Expected " + nodeCount + " nodes, but got " + consumer .index );
73+ }
74+ }
75+
76+ Bdd (int [] variables , int [] highs , int [] lows , int nodeCount , int rootRef , int conditionCount , int resultCount ) {
77+ validateArrays (variables , highs , lows , nodeCount );
78+ validateCounts (conditionCount , resultCount , nodeCount );
79+ validateRootReference (rootRef , nodeCount );
80+
81+ this .variables = variables ;
82+ this .highs = highs ;
83+ this .lows = lows ;
84+ this .rootRef = rootRef ;
85+ this .conditionCount = conditionCount ;
86+ this .resultCount = resultCount ;
87+ this .nodeCount = nodeCount ;
88+ }
89+
90+ private static void validateCounts (int conditionCount , int resultCount , int nodeCount ) {
91+ if (conditionCount < 0 ) {
92+ throw new IllegalArgumentException ("Condition count cannot be negative: " + conditionCount );
93+ } else if (resultCount < 0 ) {
94+ throw new IllegalArgumentException ("Result count cannot be negative: " + resultCount );
95+ } else if (nodeCount < 0 ) {
96+ throw new IllegalArgumentException ("Node count cannot be negative: " + nodeCount );
97+ }
98+ }
99+
100+ private static void validateRootReference (int rootRef , int nodeCount ) {
101+ if (isComplemented (rootRef ) && !isTerminal (rootRef )) {
102+ throw new IllegalArgumentException ("Root reference cannot be complemented: " + rootRef );
103+ } else if (isNodeReference (rootRef )) {
104+ int idx = Math .abs (rootRef ) - 1 ;
105+ if (idx >= nodeCount ) {
106+ throw new IllegalArgumentException ("Root points to invalid BDD node: " + idx +
107+ " (node count: " + nodeCount + ")" );
108+ }
109+ }
110+ }
111+
112+ private static void validateArrays (int [] variables , int [] highs , int [] lows , int nodeCount ) {
113+ if (variables .length != highs .length || variables .length != lows .length ) {
114+ throw new IllegalArgumentException ("Array lengths must match: variables=" + variables .length +
115+ ", highs=" + highs .length + ", lows=" + lows .length );
116+ } else if (nodeCount > variables .length ) {
117+ throw new IllegalArgumentException ("Node count (" + nodeCount +
118+ ") exceeds array capacity (" + variables .length + ")" );
75119 }
76120 }
77121
@@ -96,23 +140,6 @@ public void accept(int var, int high, int low) {
96140 }
97141 }
98142
99- Bdd (int [] variables , int [] highs , int [] lows , int rootRef , int conditionCount , int resultCount ) {
100- this .variables = Objects .requireNonNull (variables , "variables is null" );
101- this .highs = Objects .requireNonNull (highs , "highs is null" );
102- this .lows = Objects .requireNonNull (lows , "lows is null" );
103- this .rootRef = rootRef ;
104- this .conditionCount = conditionCount ;
105- this .resultCount = resultCount ;
106-
107- if (rootRef < 0 && rootRef != -1 ) {
108- throw new IllegalArgumentException ("Root reference cannot be complemented: " + rootRef );
109- }
110-
111- if (variables .length != highs .length || variables .length != lows .length ) {
112- throw new IllegalArgumentException ("Array lengths must match" );
113- }
114- }
115-
116143 /**
117144 * Gets the number of conditions.
118145 *
@@ -137,7 +164,7 @@ public int getResultCount() {
137164 * @return the node count
138165 */
139166 public int getNodeCount () {
140- return variables . length ;
167+ return nodeCount ;
141168 }
142169
143170 /**
@@ -156,16 +183,24 @@ public int getRootRef() {
156183 * @return the variable index
157184 */
158185 public int getVariable (int nodeIndex ) {
186+ validateRange (nodeIndex );
159187 return variables [nodeIndex ];
160188 }
161189
190+ private void validateRange (int index ) {
191+ if (index < 0 || index >= nodeCount ) {
192+ throw new IndexOutOfBoundsException ("Node index out of bounds: " + index + " (size: " + nodeCount + ")" );
193+ }
194+ }
195+
162196 /**
163197 * Gets the high (true) reference for a node.
164198 *
165199 * @param nodeIndex the node index (0-based)
166200 * @return the high reference
167201 */
168202 public int getHigh (int nodeIndex ) {
203+ validateRange (nodeIndex );
169204 return highs [nodeIndex ];
170205 }
171206
@@ -176,6 +211,7 @@ public int getHigh(int nodeIndex) {
176211 * @return the low reference
177212 */
178213 public int getLow (int nodeIndex ) {
214+ validateRange (nodeIndex );
179215 return lows [nodeIndex ];
180216 }
181217
@@ -185,7 +221,7 @@ public int getLow(int nodeIndex) {
185221 * @param consumer the consumer to receive the integers
186222 */
187223 public void getNodes (BddNodeConsumer consumer ) {
188- for (int i = 0 ; i < variables . length ; i ++) {
224+ for (int i = 0 ; i < nodeCount ; i ++) {
189225 consumer .accept (variables [i ], highs [i ], lows [i ]);
190226 }
191227 }
@@ -201,17 +237,14 @@ public int evaluate(ConditionEvaluator ev) {
201237 int [] vars = this .variables ;
202238 int [] hi = this .highs ;
203239 int [] lo = this .lows ;
204- int off = RESULT_OFFSET ;
205240
206- // keep walking while ref is a non-terminal node
207- while ((ref > 1 && ref < off ) || (ref < -1 && ref > -off )) {
241+ while (isNodeReference (ref )) {
208242 int idx = ref > 0 ? ref - 1 : -ref - 1 ; // Math.abs
209243 // test ^ complement, pick hi or lo
210244 ref = (ev .test (vars [idx ]) ^ (ref < 0 )) ? hi [idx ] : lo [idx ];
211245 }
212246
213- // +1/-1 => no match
214- return (ref == 1 || ref == -1 ) ? -1 : (ref - off );
247+ return isTerminal (ref ) ? -1 : ref - RESULT_OFFSET ;
215248 }
216249
217250 /**
@@ -221,10 +254,7 @@ public int evaluate(ConditionEvaluator ev) {
221254 * @return true if this is a node reference
222255 */
223256 public static boolean isNodeReference (int ref ) {
224- if (ref == 0 || isTerminal (ref )) {
225- return false ;
226- }
227- return Math .abs (ref ) < RESULT_OFFSET ;
257+ return (ref > 1 && ref < RESULT_OFFSET ) || (ref < -1 && ref > -RESULT_OFFSET );
228258 }
229259
230260 /**
@@ -264,21 +294,31 @@ public boolean equals(Object obj) {
264294 } else if (!(obj instanceof Bdd )) {
265295 return false ;
266296 }
297+
267298 Bdd other = (Bdd ) obj ;
268- return rootRef == other .rootRef
269- && conditionCount == other .conditionCount
270- && resultCount == other .resultCount
271- && Arrays .equals (variables , other .variables )
272- && Arrays .equals (highs , other .highs )
273- && Arrays .equals (lows , other .lows );
299+ if (rootRef != other .rootRef
300+ || conditionCount != other .conditionCount
301+ || resultCount != other .resultCount
302+ || nodeCount != other .nodeCount ) {
303+ return false ;
304+ }
305+
306+ // Now check the views of arrays of each.
307+ for (int i = 0 ; i < nodeCount ; i ++) {
308+ if (variables [i ] != other .variables [i ] || highs [i ] != other .highs [i ] || lows [i ] != other .lows [i ]) {
309+ return false ;
310+ }
311+ }
312+
313+ return true ;
274314 }
275315
276316 @ Override
277317 public int hashCode () {
278- int hash = 31 * rootRef + variables . length ;
318+ int hash = 31 * rootRef + nodeCount ;
279319 // Sample up to 16 nodes distributed across the BDD
280- int step = Math .max (1 , variables . length / 16 );
281- for (int i = 0 ; i < variables . length ; i += step ) {
320+ int step = Math .max (1 , nodeCount / 16 );
321+ for (int i = 0 ; i < nodeCount ; i += step ) {
282322 hash = 31 * hash + variables [i ];
283323 hash = 31 * hash + highs [i ];
284324 hash = 31 * hash + lows [i ];
0 commit comments