Skip to content

Commit c648fa5

Browse files
committed
Fix BddTrait logic issue, using wrong conditions
We were using the wrong condition ordering in BddTrait after compiling a Bdd from the CFG, leading to a totally broken BDD. Also adds some tests, fixes, and generalizes BddTrait transforms
1 parent 1487287 commit c648fa5

File tree

13 files changed

+697
-276
lines changed

13 files changed

+697
-276
lines changed

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java

Lines changed: 85 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
import java.io.UncheckedIOException;
1111
import java.io.Writer;
1212
import java.nio.charset.StandardCharsets;
13-
import java.util.Arrays;
14-
import java.util.Objects;
1513
import java.util.function.Consumer;
1614
import software.amazon.smithy.rulesengine.logic.ConditionEvaluator;
1715

@@ -44,6 +42,7 @@ public final class Bdd {
4442
private final int rootRef;
4543
private final int conditionCount;
4644
private final int resultCount;
45+
private final int nodeCount;
4746

4847
/**
4948
* Creates a BDD by streaming nodes directly into the structure.
@@ -55,23 +54,68 @@ public final class Bdd {
5554
* @param nodeHandler a handler that will provide nodes via a consumer
5655
*/
5756
public Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, Consumer<BddNodeConsumer> nodeHandler) {
57+
validateCounts(conditionCount, resultCount, nodeCount);
58+
validateRootReference(rootRef, nodeCount);
59+
5860
this.rootRef = rootRef;
5961
this.conditionCount = conditionCount;
6062
this.resultCount = resultCount;
61-
62-
if (rootRef < 0 && rootRef != -1) {
63-
throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef);
64-
}
63+
this.nodeCount = nodeCount;
6564

6665
InputNodeConsumer consumer = new InputNodeConsumer(nodeCount);
6766
nodeHandler.accept(consumer);
68-
6967
this.variables = consumer.variables;
7068
this.highs = consumer.highs;
7169
this.lows = consumer.lows;
7270

7371
if (consumer.index != nodeCount) {
74-
throw new IllegalStateException("Expected " + nodeCount + " node, but got " + consumer.index);
72+
throw new IllegalStateException("Expected " + nodeCount + " nodes, but got " + consumer.index);
73+
}
74+
}
75+
76+
Bdd(int[] variables, int[] highs, int[] lows, int nodeCount, int rootRef, int conditionCount, int resultCount) {
77+
validateArrays(variables, highs, lows, nodeCount);
78+
validateCounts(conditionCount, resultCount, nodeCount);
79+
validateRootReference(rootRef, nodeCount);
80+
81+
this.variables = variables;
82+
this.highs = highs;
83+
this.lows = lows;
84+
this.rootRef = rootRef;
85+
this.conditionCount = conditionCount;
86+
this.resultCount = resultCount;
87+
this.nodeCount = nodeCount;
88+
}
89+
90+
private static void validateCounts(int conditionCount, int resultCount, int nodeCount) {
91+
if (conditionCount < 0) {
92+
throw new IllegalArgumentException("Condition count cannot be negative: " + conditionCount);
93+
} else if (resultCount < 0) {
94+
throw new IllegalArgumentException("Result count cannot be negative: " + resultCount);
95+
} else if (nodeCount < 0) {
96+
throw new IllegalArgumentException("Node count cannot be negative: " + nodeCount);
97+
}
98+
}
99+
100+
private static void validateRootReference(int rootRef, int nodeCount) {
101+
if (isComplemented(rootRef) && !isTerminal(rootRef)) {
102+
throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef);
103+
} else if (isNodeReference(rootRef)) {
104+
int idx = Math.abs(rootRef) - 1;
105+
if (idx >= nodeCount) {
106+
throw new IllegalArgumentException("Root points to invalid BDD node: " + idx +
107+
" (node count: " + nodeCount + ")");
108+
}
109+
}
110+
}
111+
112+
private static void validateArrays(int[] variables, int[] highs, int[] lows, int nodeCount) {
113+
if (variables.length != highs.length || variables.length != lows.length) {
114+
throw new IllegalArgumentException("Array lengths must match: variables=" + variables.length +
115+
", highs=" + highs.length + ", lows=" + lows.length);
116+
} else if (nodeCount > variables.length) {
117+
throw new IllegalArgumentException("Node count (" + nodeCount +
118+
") exceeds array capacity (" + variables.length + ")");
75119
}
76120
}
77121

@@ -96,23 +140,6 @@ public void accept(int var, int high, int low) {
96140
}
97141
}
98142

99-
Bdd(int[] variables, int[] highs, int[] lows, int rootRef, int conditionCount, int resultCount) {
100-
this.variables = Objects.requireNonNull(variables, "variables is null");
101-
this.highs = Objects.requireNonNull(highs, "highs is null");
102-
this.lows = Objects.requireNonNull(lows, "lows is null");
103-
this.rootRef = rootRef;
104-
this.conditionCount = conditionCount;
105-
this.resultCount = resultCount;
106-
107-
if (rootRef < 0 && rootRef != -1) {
108-
throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef);
109-
}
110-
111-
if (variables.length != highs.length || variables.length != lows.length) {
112-
throw new IllegalArgumentException("Array lengths must match");
113-
}
114-
}
115-
116143
/**
117144
* Gets the number of conditions.
118145
*
@@ -137,7 +164,7 @@ public int getResultCount() {
137164
* @return the node count
138165
*/
139166
public int getNodeCount() {
140-
return variables.length;
167+
return nodeCount;
141168
}
142169

143170
/**
@@ -156,16 +183,24 @@ public int getRootRef() {
156183
* @return the variable index
157184
*/
158185
public int getVariable(int nodeIndex) {
186+
validateRange(nodeIndex);
159187
return variables[nodeIndex];
160188
}
161189

190+
private void validateRange(int index) {
191+
if (index < 0 || index >= nodeCount) {
192+
throw new IndexOutOfBoundsException("Node index out of bounds: " + index + " (size: " + nodeCount + ")");
193+
}
194+
}
195+
162196
/**
163197
* Gets the high (true) reference for a node.
164198
*
165199
* @param nodeIndex the node index (0-based)
166200
* @return the high reference
167201
*/
168202
public int getHigh(int nodeIndex) {
203+
validateRange(nodeIndex);
169204
return highs[nodeIndex];
170205
}
171206

@@ -176,6 +211,7 @@ public int getHigh(int nodeIndex) {
176211
* @return the low reference
177212
*/
178213
public int getLow(int nodeIndex) {
214+
validateRange(nodeIndex);
179215
return lows[nodeIndex];
180216
}
181217

@@ -185,7 +221,7 @@ public int getLow(int nodeIndex) {
185221
* @param consumer the consumer to receive the integers
186222
*/
187223
public void getNodes(BddNodeConsumer consumer) {
188-
for (int i = 0; i < variables.length; i++) {
224+
for (int i = 0; i < nodeCount; i++) {
189225
consumer.accept(variables[i], highs[i], lows[i]);
190226
}
191227
}
@@ -201,17 +237,14 @@ public int evaluate(ConditionEvaluator ev) {
201237
int[] vars = this.variables;
202238
int[] hi = this.highs;
203239
int[] lo = this.lows;
204-
int off = RESULT_OFFSET;
205240

206-
// keep walking while ref is a non-terminal node
207-
while ((ref > 1 && ref < off) || (ref < -1 && ref > -off)) {
241+
while (isNodeReference(ref)) {
208242
int idx = ref > 0 ? ref - 1 : -ref - 1; // Math.abs
209243
// test ^ complement, pick hi or lo
210244
ref = (ev.test(vars[idx]) ^ (ref < 0)) ? hi[idx] : lo[idx];
211245
}
212246

213-
// +1/-1 => no match
214-
return (ref == 1 || ref == -1) ? -1 : (ref - off);
247+
return isTerminal(ref) ? -1 : ref - RESULT_OFFSET;
215248
}
216249

217250
/**
@@ -221,10 +254,7 @@ public int evaluate(ConditionEvaluator ev) {
221254
* @return true if this is a node reference
222255
*/
223256
public static boolean isNodeReference(int ref) {
224-
if (ref == 0 || isTerminal(ref)) {
225-
return false;
226-
}
227-
return Math.abs(ref) < RESULT_OFFSET;
257+
return (ref > 1 && ref < RESULT_OFFSET) || (ref < -1 && ref > -RESULT_OFFSET);
228258
}
229259

230260
/**
@@ -264,21 +294,31 @@ public boolean equals(Object obj) {
264294
} else if (!(obj instanceof Bdd)) {
265295
return false;
266296
}
297+
267298
Bdd other = (Bdd) obj;
268-
return rootRef == other.rootRef
269-
&& conditionCount == other.conditionCount
270-
&& resultCount == other.resultCount
271-
&& Arrays.equals(variables, other.variables)
272-
&& Arrays.equals(highs, other.highs)
273-
&& Arrays.equals(lows, other.lows);
299+
if (rootRef != other.rootRef
300+
|| conditionCount != other.conditionCount
301+
|| resultCount != other.resultCount
302+
|| nodeCount != other.nodeCount) {
303+
return false;
304+
}
305+
306+
// Now check the views of arrays of each.
307+
for (int i = 0; i < nodeCount; i++) {
308+
if (variables[i] != other.variables[i] || highs[i] != other.highs[i] || lows[i] != other.lows[i]) {
309+
return false;
310+
}
311+
}
312+
313+
return true;
274314
}
275315

276316
@Override
277317
public int hashCode() {
278-
int hash = 31 * rootRef + variables.length;
318+
int hash = 31 * rootRef + nodeCount;
279319
// Sample up to 16 nodes distributed across the BDD
280-
int step = Math.max(1, variables.length / 16);
281-
for (int i = 0; i < variables.length; i += step) {
320+
int step = Math.max(1, nodeCount / 16);
321+
for (int i = 0; i < nodeCount; i += step) {
282322
hash = 31 * hash + variables[i];
283323
hash = 31 * hash + highs[i];
284324
hash = 31 * hash + lows[i];

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ public BddBuilder() {
6363
lows[0] = FALSE_REF;
6464
}
6565

66+
int getNodeCount() {
67+
return nodeCount;
68+
}
69+
6670
/**
6771
* Sets the number of conditions. Must be called before creating result nodes.
6872
*
@@ -170,8 +174,8 @@ private int insertNode(int var, int high, int low, boolean flip) {
170174

171175
private void ensureCapacity() {
172176
if (nodeCount >= variables.length) {
173-
// Grow by 50%
174-
int newCapacity = variables.length + (variables.length >> 1);
177+
// Double the current capacity
178+
int newCapacity = variables.length * 2;
175179
variables = Arrays.copyOf(variables, newCapacity);
176180
highs = Arrays.copyOf(highs, newCapacity);
177181
lows = Arrays.copyOf(lows, newCapacity);
@@ -592,26 +596,11 @@ public BddBuilder reset() {
592596
return this;
593597
}
594598

595-
/**
596-
* Get the nodes as a flat array.
597-
*
598-
* @return array of nodes, trimmed to actual size.
599-
*/
600-
public int[] getNodesArray() {
601-
// Convert back to flat array for compatibility
602-
int[] result = new int[nodeCount * 3];
603-
for (int i = 0; i < nodeCount; i++) {
604-
int baseIdx = i * 3;
605-
result[baseIdx] = variables[i];
606-
result[baseIdx + 1] = highs[i];
607-
result[baseIdx + 2] = lows[i];
608-
}
609-
return result;
610-
}
611-
612599
/**
613600
* Builds a BDD from the current state of the builder.
614601
*
602+
* <p>The builder must be reset() before reuse after calling this method.
603+
*
615604
* @return a new BDD instance
616605
* @throws IllegalStateException if condition count has not been set
617606
*/
@@ -620,11 +609,7 @@ Bdd build(int rootRef, int resultCount) {
620609
throw new IllegalStateException("Condition count must be set before building BDD");
621610
}
622611

623-
// Create trimmed copies of the arrays with only the used portion
624-
int[] trimmedVariables = Arrays.copyOf(variables, nodeCount);
625-
int[] trimmedHighs = Arrays.copyOf(highs, nodeCount);
626-
int[] trimmedLows = Arrays.copyOf(lows, nodeCount);
627-
return new Bdd(trimmedVariables, trimmedHighs, trimmedLows, rootRef, conditionCount, resultCount);
612+
return new Bdd(variables, highs, lows, nodeCount, rootRef, conditionCount, resultCount);
628613
}
629614

630615
private void validateBooleanOperands(int f, int g, String operation) {

smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ Bdd compile() {
6262
noMatchIndex = getOrCreateResultIndex(NoMatchRule.INSTANCE);
6363
int rootRef = convertCfgToBdd(cfg.getRoot());
6464
rootRef = bddBuilder.reduce(rootRef);
65-
6665
Bdd bdd = bddBuilder.build(rootRef, indexedResults.size());
66+
6767
long elapsed = System.currentTimeMillis() - start;
6868
LOGGER.fine(String.format(
6969
"BDD compilation complete: %d conditions, %d results, %d BDD nodes in %dms",
@@ -75,6 +75,14 @@ Bdd compile() {
7575
return bdd;
7676
}
7777

78+
List<Rule> getIndexedResults() {
79+
return indexedResults;
80+
}
81+
82+
List<Condition> getOrderedConditions() {
83+
return orderedConditions;
84+
}
85+
7886
private int convertCfgToBdd(CfgNode cfgNode) {
7987
Integer cached = nodeCache.get(cfgNode);
8088
if (cached != null) {

0 commit comments

Comments
 (0)