10
10
import java .io .UncheckedIOException ;
11
11
import java .io .Writer ;
12
12
import java .nio .charset .StandardCharsets ;
13
- import java .util .Arrays ;
14
- import java .util .Objects ;
15
13
import java .util .function .Consumer ;
16
14
import software .amazon .smithy .rulesengine .logic .ConditionEvaluator ;
17
15
@@ -44,6 +42,7 @@ public final class Bdd {
44
42
private final int rootRef ;
45
43
private final int conditionCount ;
46
44
private final int resultCount ;
45
+ private final int nodeCount ;
47
46
48
47
/**
49
48
* Creates a BDD by streaming nodes directly into the structure.
@@ -55,23 +54,68 @@ public final class Bdd {
55
54
* @param nodeHandler a handler that will provide nodes via a consumer
56
55
*/
57
56
public Bdd (int rootRef , int conditionCount , int resultCount , int nodeCount , Consumer <BddNodeConsumer > nodeHandler ) {
57
+ validateCounts (conditionCount , resultCount , nodeCount );
58
+ validateRootReference (rootRef , nodeCount );
59
+
58
60
this .rootRef = rootRef ;
59
61
this .conditionCount = conditionCount ;
60
62
this .resultCount = resultCount ;
61
-
62
- if (rootRef < 0 && rootRef != -1 ) {
63
- throw new IllegalArgumentException ("Root reference cannot be complemented: " + rootRef );
64
- }
63
+ this .nodeCount = nodeCount ;
65
64
66
65
InputNodeConsumer consumer = new InputNodeConsumer (nodeCount );
67
66
nodeHandler .accept (consumer );
68
-
69
67
this .variables = consumer .variables ;
70
68
this .highs = consumer .highs ;
71
69
this .lows = consumer .lows ;
72
70
73
71
if (consumer .index != nodeCount ) {
74
- throw new IllegalStateException ("Expected " + nodeCount + " node, but got " + consumer .index );
72
+ throw new IllegalStateException ("Expected " + nodeCount + " nodes, but got " + consumer .index );
73
+ }
74
+ }
75
+
76
+ Bdd (int [] variables , int [] highs , int [] lows , int nodeCount , int rootRef , int conditionCount , int resultCount ) {
77
+ validateArrays (variables , highs , lows , nodeCount );
78
+ validateCounts (conditionCount , resultCount , nodeCount );
79
+ validateRootReference (rootRef , nodeCount );
80
+
81
+ this .variables = variables ;
82
+ this .highs = highs ;
83
+ this .lows = lows ;
84
+ this .rootRef = rootRef ;
85
+ this .conditionCount = conditionCount ;
86
+ this .resultCount = resultCount ;
87
+ this .nodeCount = nodeCount ;
88
+ }
89
+
90
+ private static void validateCounts (int conditionCount , int resultCount , int nodeCount ) {
91
+ if (conditionCount < 0 ) {
92
+ throw new IllegalArgumentException ("Condition count cannot be negative: " + conditionCount );
93
+ } else if (resultCount < 0 ) {
94
+ throw new IllegalArgumentException ("Result count cannot be negative: " + resultCount );
95
+ } else if (nodeCount < 0 ) {
96
+ throw new IllegalArgumentException ("Node count cannot be negative: " + nodeCount );
97
+ }
98
+ }
99
+
100
+ private static void validateRootReference (int rootRef , int nodeCount ) {
101
+ if (isComplemented (rootRef ) && !isTerminal (rootRef )) {
102
+ throw new IllegalArgumentException ("Root reference cannot be complemented: " + rootRef );
103
+ } else if (isNodeReference (rootRef )) {
104
+ int idx = Math .abs (rootRef ) - 1 ;
105
+ if (idx >= nodeCount ) {
106
+ throw new IllegalArgumentException ("Root points to invalid BDD node: " + idx +
107
+ " (node count: " + nodeCount + ")" );
108
+ }
109
+ }
110
+ }
111
+
112
+ private static void validateArrays (int [] variables , int [] highs , int [] lows , int nodeCount ) {
113
+ if (variables .length != highs .length || variables .length != lows .length ) {
114
+ throw new IllegalArgumentException ("Array lengths must match: variables=" + variables .length +
115
+ ", highs=" + highs .length + ", lows=" + lows .length );
116
+ } else if (nodeCount > variables .length ) {
117
+ throw new IllegalArgumentException ("Node count (" + nodeCount +
118
+ ") exceeds array capacity (" + variables .length + ")" );
75
119
}
76
120
}
77
121
@@ -96,23 +140,6 @@ public void accept(int var, int high, int low) {
96
140
}
97
141
}
98
142
99
- Bdd (int [] variables , int [] highs , int [] lows , int rootRef , int conditionCount , int resultCount ) {
100
- this .variables = Objects .requireNonNull (variables , "variables is null" );
101
- this .highs = Objects .requireNonNull (highs , "highs is null" );
102
- this .lows = Objects .requireNonNull (lows , "lows is null" );
103
- this .rootRef = rootRef ;
104
- this .conditionCount = conditionCount ;
105
- this .resultCount = resultCount ;
106
-
107
- if (rootRef < 0 && rootRef != -1 ) {
108
- throw new IllegalArgumentException ("Root reference cannot be complemented: " + rootRef );
109
- }
110
-
111
- if (variables .length != highs .length || variables .length != lows .length ) {
112
- throw new IllegalArgumentException ("Array lengths must match" );
113
- }
114
- }
115
-
116
143
/**
117
144
* Gets the number of conditions.
118
145
*
@@ -137,7 +164,7 @@ public int getResultCount() {
137
164
* @return the node count
138
165
*/
139
166
public int getNodeCount () {
140
- return variables . length ;
167
+ return nodeCount ;
141
168
}
142
169
143
170
/**
@@ -156,16 +183,24 @@ public int getRootRef() {
156
183
* @return the variable index
157
184
*/
158
185
public int getVariable (int nodeIndex ) {
186
+ validateRange (nodeIndex );
159
187
return variables [nodeIndex ];
160
188
}
161
189
190
+ private void validateRange (int index ) {
191
+ if (index < 0 || index >= nodeCount ) {
192
+ throw new IndexOutOfBoundsException ("Node index out of bounds: " + index + " (size: " + nodeCount + ")" );
193
+ }
194
+ }
195
+
162
196
/**
163
197
* Gets the high (true) reference for a node.
164
198
*
165
199
* @param nodeIndex the node index (0-based)
166
200
* @return the high reference
167
201
*/
168
202
public int getHigh (int nodeIndex ) {
203
+ validateRange (nodeIndex );
169
204
return highs [nodeIndex ];
170
205
}
171
206
@@ -176,6 +211,7 @@ public int getHigh(int nodeIndex) {
176
211
* @return the low reference
177
212
*/
178
213
public int getLow (int nodeIndex ) {
214
+ validateRange (nodeIndex );
179
215
return lows [nodeIndex ];
180
216
}
181
217
@@ -185,7 +221,7 @@ public int getLow(int nodeIndex) {
185
221
* @param consumer the consumer to receive the integers
186
222
*/
187
223
public void getNodes (BddNodeConsumer consumer ) {
188
- for (int i = 0 ; i < variables . length ; i ++) {
224
+ for (int i = 0 ; i < nodeCount ; i ++) {
189
225
consumer .accept (variables [i ], highs [i ], lows [i ]);
190
226
}
191
227
}
@@ -201,17 +237,14 @@ public int evaluate(ConditionEvaluator ev) {
201
237
int [] vars = this .variables ;
202
238
int [] hi = this .highs ;
203
239
int [] lo = this .lows ;
204
- int off = RESULT_OFFSET ;
205
240
206
- // keep walking while ref is a non-terminal node
207
- while ((ref > 1 && ref < off ) || (ref < -1 && ref > -off )) {
241
+ while (isNodeReference (ref )) {
208
242
int idx = ref > 0 ? ref - 1 : -ref - 1 ; // Math.abs
209
243
// test ^ complement, pick hi or lo
210
244
ref = (ev .test (vars [idx ]) ^ (ref < 0 )) ? hi [idx ] : lo [idx ];
211
245
}
212
246
213
- // +1/-1 => no match
214
- return (ref == 1 || ref == -1 ) ? -1 : (ref - off );
247
+ return isTerminal (ref ) ? -1 : ref - RESULT_OFFSET ;
215
248
}
216
249
217
250
/**
@@ -221,10 +254,7 @@ public int evaluate(ConditionEvaluator ev) {
221
254
* @return true if this is a node reference
222
255
*/
223
256
public static boolean isNodeReference (int ref ) {
224
- if (ref == 0 || isTerminal (ref )) {
225
- return false ;
226
- }
227
- return Math .abs (ref ) < RESULT_OFFSET ;
257
+ return (ref > 1 && ref < RESULT_OFFSET ) || (ref < -1 && ref > -RESULT_OFFSET );
228
258
}
229
259
230
260
/**
@@ -264,21 +294,31 @@ public boolean equals(Object obj) {
264
294
} else if (!(obj instanceof Bdd )) {
265
295
return false ;
266
296
}
297
+
267
298
Bdd other = (Bdd ) obj ;
268
- return rootRef == other .rootRef
269
- && conditionCount == other .conditionCount
270
- && resultCount == other .resultCount
271
- && Arrays .equals (variables , other .variables )
272
- && Arrays .equals (highs , other .highs )
273
- && Arrays .equals (lows , other .lows );
299
+ if (rootRef != other .rootRef
300
+ || conditionCount != other .conditionCount
301
+ || resultCount != other .resultCount
302
+ || nodeCount != other .nodeCount ) {
303
+ return false ;
304
+ }
305
+
306
+ // Now check the views of arrays of each.
307
+ for (int i = 0 ; i < nodeCount ; i ++) {
308
+ if (variables [i ] != other .variables [i ] || highs [i ] != other .highs [i ] || lows [i ] != other .lows [i ]) {
309
+ return false ;
310
+ }
311
+ }
312
+
313
+ return true ;
274
314
}
275
315
276
316
@ Override
277
317
public int hashCode () {
278
- int hash = 31 * rootRef + variables . length ;
318
+ int hash = 31 * rootRef + nodeCount ;
279
319
// Sample up to 16 nodes distributed across the BDD
280
- int step = Math .max (1 , variables . length / 16 );
281
- for (int i = 0 ; i < variables . length ; i += step ) {
320
+ int step = Math .max (1 , nodeCount / 16 );
321
+ for (int i = 0 ; i < nodeCount ; i += step ) {
282
322
hash = 31 * hash + variables [i ];
283
323
hash = 31 * hash + highs [i ];
284
324
hash = 31 * hash + lows [i ];
0 commit comments