diff --git a/config/spotbugs/filter.xml b/config/spotbugs/filter.xml index 6188a7e3059..a92ee96eade 100644 --- a/config/spotbugs/filter.xml +++ b/config/spotbugs/filter.xml @@ -218,4 +218,10 @@ + + + + + + diff --git a/docs/source-2.0/additional-specs/rules-engine/specification.rst b/docs/source-2.0/additional-specs/rules-engine/specification.rst index f1861f97689..01ce02bef77 100644 --- a/docs/source-2.0/additional-specs/rules-engine/specification.rst +++ b/docs/source-2.0/additional-specs/rules-engine/specification.rst @@ -14,18 +14,27 @@ are composed of a set of *conditions*, which determine if a rule should be selected, and a result. Conditions act on the defined parameters, and allow for the modeling of statements. -When a rule’s conditions are evaluated successfully, the rule provides either a +When a rule's conditions are evaluated successfully, the rule provides either a result and its accompanying requirements or an error describing the unsupported state. Modeled endpoint errors allow for more explicit descriptions to users, such as providing errors when a service doesn't support a combination of conditions. +-------------------- +Rules engine version +-------------------- + +The rules engine specification is versioned, with the current version being 1.1. Unless otherwise specified, functions, +features, and built-ins have been available since version 1.0. Any feature, function, or built-in used in the +``endpointRuleSet`` or ``endpointBdd`` traits MUST be supported by the declared version of the trait. In other words, +the feature's introduction version must be less than or equal to the trait version. .. smithy-trait:: smithy.rules#endpointRuleSet .. _smithy.rules#endpointRuleSet-trait: +-------------------------------------- ``smithy.rules#endpointRuleSet`` trait -====================================== +-------------------------------------- Summary Defines a rule set for deriving service endpoints at runtime. @@ -45,8 +54,7 @@ The content of the ``endpointRuleSet`` document has the following properties: - Description * - version - ``string`` - - **Required**. The rule set schema version. This specification covers - version 1.0 of the endpoint rule set. + - **Required**. The rules engine version (e.g., 1.0). * - serviceId - ``string`` - **Required**. An identifier for the corresponding service. @@ -74,6 +82,184 @@ or :ref:`error rules, ` with an empty set of conditions to provide a more meaningful default or error depending on the scenario. +.. smithy-trait:: smithy.rules#endpointBdd +.. _smithy.rules#endpointBdd-trait: + +---------------------------------- +``smithy.rules#endpointBdd`` trait +---------------------------------- + +.. warning:: Experimental + + This trait is experimental and subject to change. + +Summary + A Binary `Decision Diagram (BDD) `_ representation of + endpoint rules that is more compact and efficient at runtime than the decision-tree-based EndpointRuleSet trait. +Trait selector + ``service`` +Value type + ``structure`` + +The ``endpointBdd`` trait provides a BDD representation of endpoint rules, optimizing runtime evaluation by +eliminating redundant condition evaluations and reducing the decision tree to a minimal directed acyclic graph. +This trait is an alternative to ``endpointRuleSet`` that trades compile-time complexity for significantly improved +runtime performance and reduced artifact sizes. + +.. note:: + + The ``endpointBdd`` trait can be generated from an ``endpointRuleSet`` trait through compilation. Services may + provide either trait, with ``endpointBdd`` preferred for production use due to its performance characteristics. + +The ``endpointBdd`` structure has the following properties: + +.. list-table:: + :header-rows: 1 + :widths: 10 30 60 + + * - Property name + - Type + - Description + * - version + - ``string`` + - **Required**. The endpoint rules engine version. Must be at least version 1.1. + * - parameters + - ``map`` of `Parameter object`_ + - **Required**. A map of zero or more endpoint parameter names to + their parameter configuration. Uses the same parameter structure as + ``endpointRuleSet``. + * - conditions + - ``array`` of `Condition object`_ + - **Required**. Array of conditions that are evaluated during BDD + traversal. Each condition is referenced by its index in this array. + * - results + - ``array`` of `Endpoint rule object`_ or `Error rule object`_ + - **Required**. Array of possible endpoint results. The implicit `NoMatchRule` at BDD reference 0 is not included + in the array. These rule objects MUST NOT contain conditions. + * - root + - ``integer`` + - **Required**. The root reference where BDD evaluation begins. + * - nodeCount + - ``integer`` + - **Required**. The total number of nodes in the BDD. Used for validation and exact-sizing arrays during + deserialization. + * - nodes + - ``string`` + - **Required**. Base64-encoded binary representation of BDD nodes. Each node is encoded as three 4-byte + integers: ``[conditionIndex, highRef, lowRef]``. + +.. _rules-engine-endpoint-bdd-node-structure: + +BDD node structure +------------------ + +Each BDD node is encoded as a triple of integers: + +* ``conditionIndex``: Zero-based index into the ``conditions`` array +* ``highRef``: Reference to follow when the condition evaluates to true +* ``lowRef``: Reference to follow when the condition evaluates to false + +The first node, index 0, is always the terminal node ``[-1, 1, -1]`` and MUST NOT be referenced directly. This node +serves as the canonical base case for BDD reduction algorithms. + +.. _rules-engine-endpoint-bdd-reference-encoding: + +Reference encoding +------------------ + +BDD references use the following encoding scheme: + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Reference value + - Description + * - ``0`` + - Invalid/unused reference (never appears in valid BDDs) + * - ``1`` + - TRUE terminal (no match in endpoint resolution) + * - ``-1`` + - FALSE terminal (no match in endpoint resolution) + * - ``2, 3, 4, ...`` + - Node references (points to ``nodes[ref-1]``) + * - ``-2, -3, -4, ...`` + - Complement edges (logical NOT of the referenced node) + * - ``100000000+`` + - Result terminals (100000000 + resultIndex) + +When traversing a complement edge (negative reference), the high and low branches are swapped during evaluation. +This enables significant node sharing and BDD size reduction. + +.. _rules-engine-endpoint-bdd-binary-encoding: + +Binary node encoding +-------------------- + +Nodes are encoded as a Base64 string using binary encoding for efficiency: + +* Each node consists of three 4-byte big-endian integers +* Nodes are concatenated sequentially: ``[node0, node1, ..., nodeN-1]`` +* The resulting byte array is Base64-encoded + + +.. note:: Why binary? + + This encoding provides: + + * **Size efficiency**: smaller than an array of JSON integers, or an array of arrays of integers + * **Performance**: Direct deserialization into the target data structure (e.g., primitive arrays and integers) + * **Cleaner diffs**: BDD node changes appear as single-line modifications rather than spread over thousands + of numbers. + +.. _rules-engine-endpoint-bdd-evaluation: + +BDD evaluation +-------------- + +BDD evaluation follows these steps: + +#. Start at the root reference +#. While the reference is a node reference (not a terminal or result): + + * Extract the node index: ``nodeIndex = |ref| - 1`` + * Retrieve the node at that index + * Evaluate the condition at ``conditionIndex`` + * Determine which branch to follow: + + * If the reference is complemented (negative) AND condition is true: follow ``lowRef`` + * If the reference is complemented (negative) AND condition is false: follow ``highRef`` + * If the reference is positive AND condition is true: follow ``highRef`` + * If the reference is positive AND condition is false: follow ``lowRef`` + + * Update the reference to the chosen branch and continue + +#. When reaching a terminal or result: + + * For result references ≥ 100000000: return ``results[ref - 100000000]`` + * For terminals (1 or -1): return the ``NoMatchRule`` + +For example, a reference of 100000003 would return ``results[3]``, while a reference of 1 or -1 indicates no matching +rule was found. + +.. _rules-engine-endpoint-bdd-validation: + +Validation requirements +----------------------- + +* **Root reference**: MUST NOT be complemented (negative) +* **Reference validity**: All references MUST be valid: + + * ``0`` is forbidden + * Node references MUST point to existing nodes + * Result references MUST point to existing results + +* **Node structure**: Each node MUST be a properly formed triple +* **Condition indices**: Each node's condition index MUST be within ``[0, conditionCount)`` +* **Result structure**: The first result (index 0) implicitly represents ``NoMatchRule`` and is not serialized. + All serialized results MUST be either ``EndpointRule`` or ``ErrorRule`` objects without conditions. +* **Version requirement**: The version MUST be at least 1.1 + .. _rules-engine-endpoint-rule-set-parameter: ---------------- @@ -119,7 +305,7 @@ allow values to be bound to parameters from other locations in generated clients. Parameters MAY be annotated with the ``builtIn`` property, which designates that -the parameter should be bound to a value determined by the built-in’s name. The +the parameter should be bound to a value determined by the built-in's name. The :ref:`rules engine contains built-ins ` and the set is extensible. diff --git a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst index f628d689286..6ce48785ef2 100644 --- a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst +++ b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst @@ -38,6 +38,53 @@ parameter is equal to the value ``false``: } +.. _rules-engine-standard-library-coalesce: + +``coalesce`` function +===================== + +Summary + Evaluates arguments in order and returns the first non-empty result, otherwise returns the result of the last + argument. +Argument types + * This function is variadic and requires two or more arguments, each of type ``T`` or ``option`` + * All arguments must have the same inner type ``T`` +Return type + * ``coalesce(T, T, ...)`` → ``T`` + * ``coalesce(option, T, ...)`` → ``T`` (if any argument is non-optional) + * ``coalesce(T, option, ...)`` → ``T`` (if any argument is non-optional) + * ``coalesce(option, option, ...)`` → ``option`` (if all arguments are optional) +Since + 1.1 + +The ``coalesce`` function provides null-safe chaining by evaluating arguments in order and returning the first +non-empty result. If all arguments leading up to the last argument evaluate to empty, it returns the result of the +last argument. This is particularly useful for providing default values for optional parameters, chaining multiple +optional values together, and related optimizations. + +The function accepts two or more arguments, all of which must have the same inner type after unwrapping any +optionals. The return type is ``option`` only if all arguments are ``option``; otherwise it returns ``T``. + +The following example demonstrates using ``coalesce`` with multiple arguments to try several optional values +in sequence: + +.. code-block:: json + + { + "fn": "coalesce", + "argv": [ + {"ref": "customEndpoint"}, + {"ref": "regionalEndpoint"}, + {"ref": "defaultEndpoint"} + ] + } + +.. important:: + All arguments must have the same type after unwrapping any optionals (types are known at compile time and do not + need to be validated at runtime). Note that the first non-empty result is returned even if it's ``false`` + (coalesce is looking for a *non-empty* value, not a truthy value). + + .. _rules-engine-standard-library-getAttr: ``getAttr`` function diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java index c52fe6c727b..f4d05b6503f 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java @@ -4,7 +4,7 @@ */ package software.amazon.smithy.rulesengine.aws.language.functions; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -39,28 +39,67 @@ private AwsArn(Builder builder) { * @return the optional ARN. */ public static Optional parse(String arn) { - String[] base = arn.split(":", 6); - if (base.length != 6) { + if (arn == null || arn.length() < 8 || !arn.startsWith("arn:")) { return Optional.empty(); } - // First section must be "arn". - if (!base[0].equals("arn")) { + + // find each of the first five ':' positions + int p0 = 3; // after "arn" + int p1 = arn.indexOf(':', p0 + 1); + if (p1 < 0) { + return Optional.empty(); + } + + int p2 = arn.indexOf(':', p1 + 1); + if (p2 < 0) { + return Optional.empty(); + } + + int p3 = arn.indexOf(':', p2 + 1); + if (p3 < 0) { return Optional.empty(); } - // Sections for partition, service, and resource type must not be empty. - if (base[1].isEmpty() || base[2].isEmpty() || base[5].isEmpty()) { + + int p4 = arn.indexOf(':', p3 + 1); + if (p4 < 0) { + return Optional.empty(); + } + + // extract and validate mandatory parts + String partition = arn.substring(p0 + 1, p1); + String service = arn.substring(p1 + 1, p2); + String region = arn.substring(p2 + 1, p3); + String accountId = arn.substring(p3 + 1, p4); + String resource = arn.substring(p4 + 1); + + if (partition.isEmpty() || service.isEmpty() || resource.isEmpty()) { return Optional.empty(); } return Optional.of(builder() - .partition(base[1]) - .service(base[2]) - .region(base[3]) - .accountId(base[4]) - .resource(Arrays.asList(base[5].split("[:/]", -1))) + .partition(partition) + .service(service) + .region(region) + .accountId(accountId) + .resource(splitResource(resource)) .build()); } + private static List splitResource(String resource) { + List result = new ArrayList<>(); + int start = 0; + int length = resource.length(); + for (int i = 0; i < length; i++) { + char c = resource.charAt(i); + if (c == ':' || c == '/') { + result.add(resource.substring(start, i)); + start = i + 1; + } + } + result.add(resource.substring(start)); + return result; + } + /** * Builder to create an {@link AwsArn} instance. * diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java index e0c5cdd0663..f149ab7c52b 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java @@ -6,7 +6,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; @@ -14,8 +13,8 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.aws.language.functions.AwsBuiltIns; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.utils.SetUtils; @@ -33,36 +32,30 @@ public class RuleSetAwsBuiltInValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(validateRuleSetAwsBuiltIns(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class) - .getEndpointRuleSet())); + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { + EndpointRuleSetTrait trait = s.expectTrait(EndpointRuleSetTrait.class); + validateRuleSetAwsBuiltIns(events, s, trait.getEndpointRuleSet().getParameters()); + } + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateRuleSetAwsBuiltIns(events, s, s.expectTrait(EndpointBddTrait.class).getParameters()); } + return events; } - private List validateRuleSetAwsBuiltIns(ServiceShape serviceShape, EndpointRuleSet ruleSet) { - List events = new ArrayList<>(); - for (Parameter parameter : ruleSet.getParameters()) { + private void validateRuleSetAwsBuiltIns(List events, ServiceShape s, Iterable params) { + for (Parameter parameter : params) { if (parameter.isBuiltIn()) { - validateBuiltIn(serviceShape, parameter.getBuiltIn().get(), parameter).ifPresent(events::add); + validateBuiltIn(events, s, parameter.getBuiltIn().get(), parameter); } } - return events; } - private Optional validateBuiltIn( - ServiceShape serviceShape, - String builtInName, - FromSourceLocation source - ) { - if (ADDITIONAL_CONSIDERATION_BUILT_INS.contains(builtInName)) { - return Optional.of(danger( - serviceShape, - source, - String.format(ADDITIONAL_CONSIDERATION_MESSAGE, builtInName), - builtInName)); + private void validateBuiltIn(List events, ServiceShape s, String name, FromSourceLocation source) { + if (ADDITIONAL_CONSIDERATION_BUILT_INS.contains(name)) { + events.add(danger(s, source, String.format(ADDITIONAL_CONSIDERATION_MESSAGE, name), name)); } - return Optional.empty(); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java new file mode 100644 index 00000000000..d2b3443163d --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/BddCoverageChecker.java @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.analysis; + +import java.util.BitSet; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.evaluation.RuleEvaluator; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; + +/** + * Analyzes test coverage for BDD-based endpoint rules. + */ +public final class BddCoverageChecker { + + private final Parameters parameters; + private final Bdd bdd; + private final List conditions; + private final List results; + private final BitSet visitedConditions; + private final BitSet visitedResults; + + public BddCoverageChecker(EndpointBddTrait bddTrait) { + this(bddTrait.getParameters(), bddTrait.getBdd(), bddTrait.getResults(), bddTrait.getConditions()); + } + + BddCoverageChecker(Parameters parameters, Bdd bdd, List results, List conditions) { + this.results = results; + this.parameters = parameters; + this.conditions = conditions; + this.bdd = bdd; + this.visitedConditions = new BitSet(conditions.size()); + this.visitedResults = new BitSet(results.size()); + } + + /** + * Evaluates a test case and updates coverage information. + * + * @param testCase the test case to evaluate + */ + public void evaluateTestCase(EndpointTestCase testCase) { + Map input = new LinkedHashMap<>(); + for (Map.Entry entry : testCase.getParams().getStringMap().entrySet()) { + input.put(Identifier.of(entry.getKey()), Value.fromNode(entry.getValue())); + } + evaluateInput(input); + } + + /** + * Evaluates with the given inputs and updates coverage. + * + * @param input the input parameters to evaluate + */ + public void evaluateInput(Map input) { + TestEvaluator evaluator = new TestEvaluator(input); + int resultIdx = bdd.evaluate(evaluator); + if (resultIdx >= 0) { + visitedResults.set(resultIdx); + } + } + + /** + * Returns conditions that were never evaluated during testing. + * + * @return set of unevaluated conditions + */ + public Set getUnevaluatedConditions() { + Set unevaluated = new HashSet<>(); + for (int i = 0; i < conditions.size(); i++) { + if (!visitedConditions.get(i)) { + unevaluated.add(conditions.get(i)); + } + } + return unevaluated; + } + + /** + * Returns results that were never reached during testing. + * + * @return set of unreached results + */ + public Set getUnevaluatedResults() { + Set unevaluated = new HashSet<>(); + for (int i = 0; i < results.size(); i++) { + if (!visitedResults.get(i)) { + Rule result = results.get(i); + if (!(result instanceof NoMatchRule)) { + unevaluated.add(result); + } + } + } + return unevaluated; + } + + /** + * Returns the percentage of conditions that were evaluated at least once. + * + * @return condition coverage percentage (0-100) + */ + public double getConditionCoverage() { + return conditions.isEmpty() ? 100.0 : (100.0 * visitedConditions.cardinality() / conditions.size()); + } + + /** + * Returns the percentage of results that were reached at least once. + * + * @return result coverage percentage (0-100) + */ + public double getResultCoverage() { + // Count only non-NO_MATCH results + int relevantResults = 0; + int coveredRelevantResults = 0; + + for (int i = 0; i < results.size(); i++) { + if (!(results.get(i) instanceof NoMatchRule)) { + relevantResults++; + if (visitedResults.get(i)) { + coveredRelevantResults++; + } + } + } + + return relevantResults == 0 ? 100.0 : (100.0 * coveredRelevantResults / relevantResults); + } + + // Evaluator that tracks what gets visited during BDD evaluation. + private final class TestEvaluator implements ConditionEvaluator { + private final RuleEvaluator ruleEvaluator; + + TestEvaluator(Map input) { + this.ruleEvaluator = new RuleEvaluator(parameters, input); + } + + @Override + public boolean test(int conditionIndex) { + if (conditionIndex < 0 || conditionIndex >= conditions.size()) { + return false; + } else { + visitedConditions.set(conditionIndex); + Condition condition = conditions.get(conditionIndex); + return ruleEvaluator.evaluateCondition(condition).isTruthy(); + } + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java index 240422aa4c4..ba7d5113c12 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java @@ -6,6 +6,7 @@ import java.util.List; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; @@ -36,6 +37,7 @@ public List getLibraryFunctions() { IsSet.getDefinition(), IsValidHostLabel.getDefinition(), Not.getDefinition(), + Coalesce.getDefinition(), ParseUrl.getDefinition(), StringEquals.getDefinition(), Substring.getDefinition(), diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java index b1b23f7a997..f9b407fb196 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java @@ -206,21 +206,21 @@ public int hashCode() { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("url: ").append(url).append("\n"); + sb.append("url: ").append(url); if (!headers.isEmpty()) { - sb.append("headers:\n"); + sb.append("\nheaders:"); for (Map.Entry> entry : headers.entrySet()) { - sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)) - .append("\n"); + sb.append("\n"); + sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)); } } if (!properties.isEmpty()) { - sb.append("properties:\n"); + sb.append("\nproperties:"); for (Map.Entry entry : properties.entrySet()) { - sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)) - .append("\n"); + sb.append("\n"); + sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java index 9f25f9ce426..57c1b423d22 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java @@ -58,6 +58,7 @@ private static final class LazyEndpointComponentFactoryHolder { private final List rules; private final SourceLocation sourceLocation; private final String version; + private final RulesVersion rulesVersion; private EndpointRuleSet(Builder builder) { super(); @@ -65,6 +66,7 @@ private EndpointRuleSet(Builder builder) { rules = builder.rules.copy(); sourceLocation = SmithyBuilder.requiredState("source", builder.getSourceLocation()); version = SmithyBuilder.requiredState(VERSION, builder.version); + rulesVersion = RulesVersion.of(version); } /** @@ -130,6 +132,15 @@ public String getVersion() { return version; } + /** + * Get the parsed rules engine version. + * + * @return parsed version. + */ + public RulesVersion getRulesVersion() { + return rulesVersion; + } + public Type typeCheck() { return typeCheck(new Scope<>()); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/RulesVersion.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/RulesVersion.java new file mode 100644 index 00000000000..d0302067dce --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/RulesVersion.java @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language; + +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import software.amazon.smithy.utils.SmithyUnstableApi; +import software.amazon.smithy.utils.StringUtils; + +/** + * Represents the rules engine version with major and minor components. + */ +@SmithyUnstableApi +public final class RulesVersion implements Comparable { + + private static final ConcurrentHashMap CACHE = new ConcurrentHashMap<>(); + + public static final RulesVersion V1_0 = of("1.0"); + public static final RulesVersion V1_1 = of("1.1"); + + private final int major; + private final int minor; + private final String stringValue; + private final int hashCode; + + private RulesVersion(int major, int minor) { + if (major < 0 || minor < 0) { + throw new IllegalArgumentException("Version components must be non-negative"); + } + + this.major = major; + this.minor = minor; + this.stringValue = major + "." + minor; + this.hashCode = Objects.hash(major, minor); + } + + /** + * Creates a RulesVersion from a string representation. + * + * @param version the version string (e.g., "1.0", "1.2") + * @return the RulesVersion instance + * @throws IllegalArgumentException if the version string is invalid + */ + public static RulesVersion of(String version) { + return CACHE.computeIfAbsent(version, RulesVersion::parse); + } + + /** + * Creates a RulesVersion from components. + * + * @param major the major version + * @param minor the minor version + * @return the RulesVersion instance + */ + public static RulesVersion of(int major, int minor) { + String key = major + "." + minor; + return CACHE.computeIfAbsent(key, k -> new RulesVersion(major, minor)); + } + + private static RulesVersion parse(String version) { + if (StringUtils.isEmpty(version)) { + throw new IllegalArgumentException("Version string cannot be null or empty"); + } + + String[] parts = version.split("\\."); + if (parts.length < 2) { + throw new IllegalArgumentException("Invalid version: `" + version + "`. Expected format: major.minor"); + } + + try { + int major = Integer.parseInt(parts[0]); + int minor = Integer.parseInt(parts[1]); + return new RulesVersion(major, minor); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid version format: " + version, e); + } + } + + /** + * Gets the major version component. + * + * @return the major version + */ + public int getMajor() { + return major; + } + + /** + * Gets the minor version component. + * + * @return the minor version + */ + public int getMinor() { + return minor; + } + + /** + * Checks if this version is at least the specified version. + * + * @param other the version to compare against + * @return true if this version >= other + */ + public boolean isAtLeast(RulesVersion other) { + return compareTo(other) >= 0; + } + + @Override + public int compareTo(RulesVersion other) { + if (this == other) { + return 0; + } + + int result = Integer.compare(major, other.major); + if (result != 0) { + return result; + } else { + return Integer.compare(minor, other.minor); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } else if (!(obj instanceof RulesVersion)) { + return false; + } + RulesVersion other = (RulesVersion) obj; + return major == other.major && minor == other.minor; + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public String toString() { + return stringValue; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index da81e346228..6e8d70a8771 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -19,9 +19,15 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.RuleValueVisitor; +import software.amazon.smithy.rulesengine.logic.RuleBasedConditionEvaluator; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -31,6 +37,23 @@ public class RuleEvaluator implements ExpressionVisitor { private final Scope scope = new Scope<>(); + public RuleEvaluator() {} + + /** + * Create a rule evaluator from parameters and using an initial set of arguments. + * + *

This is primarily used for manually driven condition evaluation. + * + * @param parameters Parameters of the evaluator, used to initialize defaults and parameters. + * @param parameterArguments Arguments used to initialize evaluation scope state. + */ + public RuleEvaluator(Parameters parameters, Map parameterArguments) { + for (Parameter parameter : parameters) { + parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value)); + } + parameterArguments.forEach(scope::insert); + } + /** * Initializes a new {@link RuleEvaluator} instances, and evaluates * the provided ruleset and parameter arguments. @@ -44,6 +67,67 @@ public static Value evaluate(EndpointRuleSet ruleset, Map par return new RuleEvaluator().evaluateRuleSet(ruleset, parameterArguments); } + /** + * Initializes a new {@link RuleEvaluator} instances, and evaluates the provided BDD and parameter arguments. + * + * @param trait The trait to evaluate. + * @param args The rule-set parameter identifiers and values to evaluate the BDD against. + * @return The resulting value from the final matched rule. + */ + public static Value evaluate(EndpointBddTrait trait, Map args) { + return evaluate(trait.getBdd(), trait.getParameters(), trait.getConditions(), trait.getResults(), args); + } + + /** + * Initializes a new {@link RuleEvaluator} instances, and evaluates the provided BDD and parameter arguments. + * + * @param bdd The endpoint bdd. + * @param parameterArguments The rule-set parameter identifiers and values to evaluate the BDD against. + * @return The resulting value from the final matched rule. + */ + public static Value evaluate( + Bdd bdd, + Parameters parameters, + List conditions, + List results, + Map parameterArguments + ) { + return new RuleEvaluator().evaluateBdd(bdd, parameters, conditions, results, parameterArguments); + } + + private Value evaluateBdd( + Bdd bdd, + Parameters parameters, + List conditions, + List results, + Map parameterArguments + ) { + return scope.inScope(() -> { + for (Parameter parameter : parameters) { + parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value)); + } + + parameterArguments.forEach(scope::insert); + + Condition[] conds = conditions.toArray(new Condition[0]); + RuleBasedConditionEvaluator conditionEvaluator = new RuleBasedConditionEvaluator(this, conds); + int result = bdd.evaluate(conditionEvaluator); + + if (result < 0) { + throw new RuntimeException("No BDD result matched"); + } + + Rule rule = results.get(result); + if (rule instanceof EndpointRule) { + return resolveEndpoint(this, ((EndpointRule) rule).getEndpoint()); + } else if (rule instanceof ErrorRule) { + return resolveError(this, ((ErrorRule) rule).getError()); + } else { + throw new RuntimeException("Invalid BDD rule result: " + rule); + } + }); + } + /** * Evaluate the provided ruleset and parameter arguments. * @@ -101,6 +185,17 @@ public Value visitIsSet(Expression fn) { return Value.booleanValue(!fn.accept(this).isEmpty()); } + @Override + public Value visitCoalesce(List expressions) { + for (Expression exp : expressions) { + Value result = exp.accept(this); + if (!result.isEmpty()) { + return result; + } + } + return Value.emptyValue(); + } + @Override public Value visitNot(Expression not) { return Value.booleanValue(!not.accept(this).expectBooleanValue().getValue()); @@ -139,7 +234,7 @@ private Value handleRule(Rule rule) { return scope.inScope(() -> { for (Condition condition : rule.getConditions()) { Value value = evaluateCondition(condition); - if (value.isEmpty() || value.equals(Value.booleanValue(false))) { + if (!value.isTruthy()) { return Value.emptyValue(); } } @@ -159,32 +254,40 @@ public Value visitTreeRule(List rules) { @Override public Value visitErrorRule(Expression error) { - return error.accept(self); + return resolveError(self, error); } @Override public Value visitEndpointRule(Endpoint endpoint) { - EndpointValue.Builder builder = EndpointValue.builder() - .sourceLocation(endpoint) - .url(endpoint.getUrl() - .accept(RuleEvaluator.this) - .expectStringValue() - .getValue()); - - for (Map.Entry entry : endpoint.getProperties().entrySet()) { - builder.putProperty(entry.getKey().toString(), entry.getValue().accept(RuleEvaluator.this)); - } - - for (Map.Entry> entry : endpoint.getHeaders().entrySet()) { - List values = new ArrayList<>(); - for (Expression expression : entry.getValue()) { - values.add(expression.accept(RuleEvaluator.this).expectStringValue().getValue()); - } - builder.putHeader(entry.getKey(), values); - } - return builder.build(); + return resolveEndpoint(self, endpoint); } }); }); } + + private static Value resolveEndpoint(RuleEvaluator self, Endpoint endpoint) { + EndpointValue.Builder builder = EndpointValue.builder() + .sourceLocation(endpoint) + .url(endpoint.getUrl() + .accept(self) + .expectStringValue() + .getValue()); + + for (Map.Entry entry : endpoint.getProperties().entrySet()) { + builder.putProperty(entry.getKey().toString(), entry.getValue().accept(self)); + } + + for (Map.Entry> entry : endpoint.getHeaders().entrySet()) { + List values = new ArrayList<>(); + for (Expression expression : entry.getValue()) { + values.add(expression.accept(self).expectStringValue().getValue()); + } + builder.putHeader(entry.getKey(), values); + } + return builder.build(); + } + + private static Value resolveError(RuleEvaluator self, Expression error) { + return error.accept(self); + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java index b80efa02c61..2c3fbdda3ce 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java @@ -13,6 +13,7 @@ import software.amazon.smithy.rulesengine.language.evaluation.value.EndpointValue; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestExpectation; import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint; @@ -34,12 +35,30 @@ private TestEvaluator() {} * @param testCase The test case. */ public static void evaluate(EndpointRuleSet ruleset, EndpointTestCase testCase) { + Value result = RuleEvaluator.evaluate(ruleset, createParams(testCase)); + processResult(result, testCase); + } + + /** + * Evaluate the given BDD and test case. Throws an exception in the event the test case does not pass. + * + * @param bdd The BDD trait to be tested. + * @param testCase The test case. + */ + public static void evaluate(EndpointBddTrait bdd, EndpointTestCase testCase) { + Value result = RuleEvaluator.evaluate(bdd, createParams(testCase)); + processResult(result, testCase); + } + + private static Map createParams(EndpointTestCase testCase) { Map parameters = new LinkedHashMap<>(); for (Map.Entry entry : testCase.getParams().getMembers().entrySet()) { parameters.put(Identifier.of(entry.getKey()), Value.fromNode(entry.getValue())); } - Value result = RuleEvaluator.evaluate(ruleset, parameters); + return parameters; + } + private static void processResult(Value result, EndpointTestCase testCase) { StringBuilder messageBuilder = new StringBuilder("while executing test case"); if (testCase.getDocumentation().isPresent()) { messageBuilder.append(" ").append(testCase.getDocumentation().get()); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java index d23096dd3d9..799d15a3453 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java @@ -10,6 +10,9 @@ * The "any" type, which matches all other types. */ public final class AnyType extends AbstractType { + + static final AnyType INSTANCE = new AnyType(); + AnyType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java index ba20c757ca9..faeadf6e358 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java @@ -4,12 +4,17 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Collections; import java.util.Objects; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; /** * The "array" type, which contains entries of a member type. */ public final class ArrayType extends AbstractType { + + private static final Optional ZERO = Optional.of(Literal.tupleLiteral(Collections.emptyList())); private final Type member; ArrayType(Type member) { @@ -51,4 +56,9 @@ public int hashCode() { public String toString() { return String.format("ArrayType[%s]", member); } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java index 8c58f670224..af0cfd62543 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java @@ -4,14 +4,26 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + /** * The "boolean" type. */ public final class BooleanType extends AbstractType { + + private static final Optional ZERO = Optional.of(Literal.of(false)); + static final BooleanType INSTANCE = new BooleanType(); + BooleanType() {} @Override public BooleanType expectBooleanType() { return this; } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java index 86ec7490935..7e9fec629de 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java @@ -10,6 +10,9 @@ * The "empty" type. */ public final class EmptyType extends AbstractType { + + static final EmptyType INSTANCE = new EmptyType(); + EmptyType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java index b10d57f667d..65a25d4ca97 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java @@ -10,6 +10,9 @@ * The "endpoint" type, representing a valid client endpoint. */ public final class EndpointType extends AbstractType { + + static final EndpointType INSTANCE = new EndpointType(); + EndpointType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java index 33a71e2a58c..713e047cac3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java @@ -4,14 +4,26 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + /** * The "integer" type. */ public final class IntegerType extends AbstractType { + + private static final Optional ZERO = Optional.of(Literal.of(0)); + static final IntegerType INSTANCE = new IntegerType(); + IntegerType() {} @Override public IntegerType expectIntegerType() { return this; } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java index d96389ce3b4..fa6150a16c9 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java @@ -5,7 +5,9 @@ package software.amazon.smithy.rulesengine.language.evaluation.type; import java.util.Objects; +import java.util.Optional; import software.amazon.smithy.rulesengine.language.error.InnerParseError; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; /** * The "optional" type, a container for a type that may or may not be present. @@ -78,4 +80,9 @@ public int hashCode() { public String toString() { return String.format("OptionalType[%s]", inner); } + + @Override + public Optional getZeroValue() { + return inner.getZeroValue(); + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java index 9738a474e4f..1546d8113ca 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java @@ -4,14 +4,26 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + /** * The "string" type. */ public final class StringType extends AbstractType { + + private static final Optional ZERO = Optional.of(Literal.of("")); + static final StringType INSTANCE = new StringType(); + StringType() {} @Override public StringType expectStringType() { return this; } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java index 15738d40b83..1313819497e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java @@ -6,8 +6,10 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -38,20 +40,20 @@ default Type provenTruthy() { } static Type fromParameterType(ParameterType parameterType) { - if (parameterType == ParameterType.STRING) { - return stringType(); + switch (parameterType) { + case STRING: + return stringType(); + case BOOLEAN: + return booleanType(); + case STRING_ARRAY: + return arrayType(stringType()); + default: + throw new IllegalArgumentException("Unexpected parameter type: " + parameterType); } - if (parameterType == ParameterType.BOOLEAN) { - return booleanType(); - } - if (parameterType == ParameterType.STRING_ARRAY) { - return arrayType(stringType()); - } - throw new IllegalArgumentException("Unexpected parameter type: " + parameterType); } static AnyType anyType() { - return new AnyType(); + return AnyType.INSTANCE; } static ArrayType arrayType(Type inner) { @@ -59,19 +61,19 @@ static ArrayType arrayType(Type inner) { } static BooleanType booleanType() { - return new BooleanType(); + return BooleanType.INSTANCE; } static EmptyType emptyType() { - return new EmptyType(); + return EmptyType.INSTANCE; } static EndpointType endpointType() { - return new EndpointType(); + return EndpointType.INSTANCE; } static IntegerType integerType() { - return new IntegerType(); + return IntegerType.INSTANCE; } static OptionalType optionalType(Type type) { @@ -83,7 +85,7 @@ static RecordType recordType(Map inner) { } static StringType stringType() { - return new StringType(); + return StringType.INSTANCE; } static TupleType tupleType(List members) { @@ -129,4 +131,17 @@ default StringType expectStringType() throws InnerParseError { default TupleType expectTupleType() throws InnerParseError { throw new InnerParseError("Expected tuple but found " + this); } + + /** + * Gets the default zero-value of the type as a Literal. + * + *

Strings, booleans, integers, and arrays have zero values. Other types do not. E.g., a map might have + * required properties, and the behavior of a tuple _seems_ to imply that each member is required. Optionals + * return the zero value of its inner type. + * + * @return the default zero value. + */ + default Optional getZeroValue() { + return Optional.empty(); + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java index 40eeb6cc48a..64b52bad7ad 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java @@ -101,4 +101,13 @@ public String toString() { } return "[" + String.join(", ", valueStrings) + "]"; } + + @Override + public Object toObject() { + List result = new ArrayList<>(values.size()); + for (Value value : values) { + result.add(value.toObject()); + } + return result; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java index 43b7b145223..e1b5f0709c5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java @@ -14,13 +14,26 @@ * A boolean value of true or false. */ public final class BooleanValue extends Value { + + static final BooleanValue TRUE = new BooleanValue(true); + static final BooleanValue FALSE = new BooleanValue(false); + private final boolean value; - BooleanValue(boolean value) { + static BooleanValue create(boolean v) { + return v ? TRUE : FALSE; + } + + private BooleanValue(boolean value) { super(SourceLocation.none()); this.value = value; } + @Override + public boolean isTruthy() { + return value; + } + /** * Gets the true or false value of this boolean. * @@ -68,4 +81,9 @@ public int hashCode() { public String toString() { return String.valueOf(value); } + + @Override + public Object toObject() { + return value; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java index 392214b450c..ad58c24fbb5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java @@ -12,10 +12,17 @@ * An empty value. */ public final class EmptyValue extends Value { + static final EmptyValue INSTANCE = new EmptyValue(); + public EmptyValue() { super(SourceLocation.none()); } + @Override + public boolean isTruthy() { + return false; + } + @Override public Type getType() { return Type.emptyType(); @@ -35,4 +42,9 @@ public Node toNode() { public String toString() { return ""; } + + @Override + public Object toObject() { + return null; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java index 324c71ab92f..b6867589589 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java @@ -163,6 +163,11 @@ public String toString() { return sb.toString(); } + @Override + public Object toObject() { + return this; + } + /** * A builder used to create an {@link EndpointValue} class. */ diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java index 99fe951ba5b..f8311556a91 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java @@ -42,4 +42,9 @@ public IntegerValue expectIntegerValue() { public Node toNode() { return Node.from(value); } + + @Override + public Object toObject() { + return value; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java index 20e7bd45a72..7a3b23abafe 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java @@ -97,4 +97,13 @@ public int hashCode() { public String toString() { return value.toString(); } + + @Override + public Object toObject() { + Map result = new HashMap<>(value.size()); + for (Map.Entry entry : value.entrySet()) { + result.put(entry.getKey().toString(), entry.getValue().toObject()); + } + return result; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java index 44f48d3cbbd..f1c659b08e3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java @@ -66,4 +66,9 @@ public int hashCode() { public String toString() { return value; } + + @Override + public Object toObject() { + return value; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java index 7d719e1cd5b..cfbb29422c5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java @@ -38,6 +38,10 @@ public abstract class Value implements FromSourceLocation, ToNode { public abstract Type getType(); + public boolean isTruthy() { + return true; + } + /** * Creates a {@link Value} of a specific type from the given Node information. * @@ -112,7 +116,7 @@ public static ArrayValue arrayValue(List value) { * @return returns the created BooleanValue. */ public static BooleanValue booleanValue(boolean value) { - return new BooleanValue(value); + return BooleanValue.create(value); } /** @@ -121,7 +125,7 @@ public static BooleanValue booleanValue(boolean value) { * @return returns the created EmptyValue. */ public static EmptyValue emptyValue() { - return new EmptyValue(); + return EmptyValue.INSTANCE; } /** @@ -230,4 +234,6 @@ private RuntimeException throwTypeMismatch(String expectedType) { getType(), this)); } + + public abstract Object toObject(); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java index 78dbaf4d2bb..cf0540cd775 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java @@ -80,7 +80,7 @@ public boolean equals(Object obj) { @Override public int hashCode() { - return Objects.hash(name); + return name.hashCode(); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java index de0a0f623eb..fc798ec8301 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java @@ -4,6 +4,7 @@ */ package software.amazon.smithy.rulesengine.language.syntax; +import software.amazon.smithy.rulesengine.language.RulesVersion; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; @@ -22,6 +23,15 @@ */ @SmithyInternalApi public abstract class SyntaxElement implements ToCondition, ToExpression { + /** + * Get the rules engine version that this syntax element is available since. + * + * @return the version this is available since. + */ + public RulesVersion availableSince() { + return RulesVersion.V1_0; + } + /** * Returns a BooleanEquals expression comparing this expression to the provided boolean value. * diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java index e70e946a35e..c69bfa44c3e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java @@ -17,8 +17,11 @@ public interface ToCondition { * Convert this into a condition builder for compositional use. * * @return the condition builder. + * @throws UnsupportedOperationException if this cannot be converted to a condition. */ - Condition.Builder toConditionBuilder(); + default Condition.Builder toConditionBuilder() { + throw new UnsupportedOperationException("Cannot convert " + getClass().getName() + " to a condition"); + } /** * Convert this into a condition. diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java index 46fb4c5d5b6..814c5c7321f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java @@ -6,8 +6,10 @@ import static software.amazon.smithy.rulesengine.language.error.RuleError.context; +import java.util.Collections; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.SourceException; import software.amazon.smithy.model.SourceLocation; @@ -24,7 +26,6 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -35,6 +36,7 @@ @SmithyUnstableApi public abstract class Expression extends SyntaxElement implements FromSourceLocation, ToNode, TypeCheck { private final SourceLocation sourceLocation; + private Set cachedReferences; private Type cachedType; public Expression(SourceLocation sourceLocation) { @@ -121,6 +123,26 @@ public static Reference getReference(Identifier name, FromSourceLocation context return new Reference(name, context); } + /** + * Constructs a {@link Reference} for the given {@link Identifier}. + * + * @param name the referenced identifier. + * @return the reference. + */ + public static Reference getReference(Identifier name) { + return getReference(name, SourceLocation.NONE); + } + + /** + * Constructs a {@link Reference} for the given {@link Identifier}. + * + * @param name the referenced identifier. + * @return the reference. + */ + public static Reference getReference(String name) { + return getReference(Identifier.of(name)); + } + /** * Constructs a {@link Literal} from the given {@link StringNode}. * @@ -131,6 +153,27 @@ public static Literal getLiteral(StringNode node) { return Literal.stringLiteral(new Template(node)); } + /** + * Get the set of variables that this condition references. + * + * @return variable references by name. + */ + public final Set getReferences() { + if (cachedReferences == null) { + cachedReferences = Collections.unmodifiableSet(calculateReferences()); + } + return cachedReferences; + } + + /** + * Computes the references of an expression. + * + * @return the computed references. + */ + protected Set calculateReferences() { + return Collections.emptySet(); + } + /** * Invoke the {@link ExpressionVisitor} functions for this expression. * @@ -154,11 +197,6 @@ public Type type() { return cachedType; } - @Override - public Condition.Builder toConditionBuilder() { - return Condition.builder().fn(this); - } - @Override public Expression toExpression() { return this; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java index 8d9da440bf0..1557b529b52 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java @@ -5,6 +5,7 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions; import java.util.List; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; @@ -49,6 +50,16 @@ public interface ExpressionVisitor { */ R visitIsSet(Expression fn); + /** + * Visits a coalesce function. + * + * @param expressions The coalesce expressions to check. + * @return the value from the visitor. + */ + default R visitCoalesce(List expressions) { + return visitLibraryFunction(Coalesce.getDefinition(), expressions); + } + /** * Visits a not function. * @@ -107,6 +118,11 @@ public R visitIsSet(Expression fn) { return getDefault(); } + @Override + public R visitCoalesce(List expressions) { + return getDefault(); + } + @Override public R visitNot(Expression not) { return getDefault(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java index 1f8a757d38a..e7341dfe745 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java @@ -6,7 +6,9 @@ import static software.amazon.smithy.rulesengine.language.error.RuleError.context; +import java.util.Collections; import java.util.Objects; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.model.node.ObjectNode; @@ -47,6 +49,11 @@ public String template() { return String.format("{%s}", name); } + @Override + protected Set calculateReferences() { + return Collections.singleton(getName().toString()); + } + @Override public R accept(ExpressionVisitor visitor) { return visitor.visitRef(this); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java index a634158cd2c..c8cb5931533 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java @@ -22,7 +22,6 @@ import software.amazon.smithy.rulesengine.language.evaluation.TypeCheck; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.ToExpression; -import software.amazon.smithy.utils.SmithyBuilder; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -41,7 +40,7 @@ public final class Template implements FromSourceLocation, ToNode { private final String value; public Template(StringNode template) { - sourceLocation = SmithyBuilder.requiredState("source", template.getSourceLocation()); + sourceLocation = template.getSourceLocation(); value = template.getValue(); parts = context("when parsing template", template, () -> parseTemplate(template.getValue(), template)); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java index bfc615a111e..066cf1cd966 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java @@ -56,6 +56,16 @@ public static BooleanEquals ofExpressions(ToExpression arg1, boolean arg2) { return ofExpressions(arg1, Expression.of(arg2)); } + @Override + public BooleanEquals canonicalize() { + List args = getArguments(); + if (shouldSwapArgs(args.get(0), args.get(1))) { + return BooleanEquals.ofExpressions(args.get(1), args.get(0)); + } + + return this; + } + @Override public R accept(ExpressionVisitor visitor) { return visitor.visitBoolEquals(functionNode.getArguments().get(0), functionNode.getArguments().get(1)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java new file mode 100644 index 00000000000..ce32f1936b4 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -0,0 +1,154 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.ToExpression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * A coalesce function that returns the first non-empty value. + * + *

This variadic function requires two or more arguments. At runtime, returns the first argument that contains a + * non-EmptyValue, otherwise returns the result of the last argument. + * + *

Type checking rules: + *

    + *
  • {@code coalesce(T, T, T) => T} (same types)
  • + *
  • {@code coalesce(Optional, T, T) => T} (any non-optional makes result non-optional)
  • + *
  • {@code coalesce(Optional, Optional, Optional) => Optional} (all optional)
  • + *
+ * + *

Available since: rules engine 1.1. + */ +@SmithyUnstableApi +public final class Coalesce extends LibraryFunction { + public static final String ID = "coalesce"; + private static final Definition DEFINITION = new Definition(); + + private Coalesce(FunctionNode functionNode) { + super(DEFINITION, functionNode); + } + + /** + * Gets the {@link FunctionDefinition} implementation. + * + * @return the function definition. + */ + public static Definition getDefinition() { + return DEFINITION; + } + + /** + * Creates a {@link Coalesce} function from variadic expressions. + * + * @param args the expressions to coalesce + * @return The resulting {@link Coalesce} function. + */ + public static Coalesce ofExpressions(ToExpression... args) { + return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, args)); + } + + /** + * Creates a {@link Coalesce} function from a list of expressions. + * + * @param args the expressions to coalesce + * @return The resulting {@link Coalesce} function. + */ + public static Coalesce ofExpressions(List args) { + return ofExpressions(args.toArray(new ToExpression[0])); + } + + @Override + public RulesVersion availableSince() { + return RulesVersion.V1_1; + } + + @Override + public R accept(ExpressionVisitor visitor) { + return visitor.visitCoalesce(getArguments()); + } + + @Override + public Type typeCheck(Scope scope) { + List args = getArguments(); + if (args.size() < 2) { + throw new IllegalArgumentException("Coalesce requires at least 2 arguments, got " + args.size()); + } + + // Get the first argument's type as the baseline + Type firstType = args.get(0).typeCheck(scope); + Type baseInnerType = getInnerType(firstType); + boolean hasNonOptional = !(firstType instanceof OptionalType); + + // Check all other arguments match the base type + for (int i = 1; i < args.size(); i++) { + Type argType = args.get(i).typeCheck(scope); + Type innerType = getInnerType(argType); + + if (!innerType.equals(baseInnerType)) { + throw new IllegalArgumentException(String.format( + "Type mismatch in coalesce at argument %d: expected %s but got %s", + i + 1, + baseInnerType, + innerType)); + } + + hasNonOptional = hasNonOptional || !(argType instanceof OptionalType); + } + + return hasNonOptional ? baseInnerType : Type.optionalType(baseInnerType); + } + + private static Type getInnerType(Type t) { + return (t instanceof OptionalType) ? ((OptionalType) t).inner() : t; + } + + /** + * A {@link FunctionDefinition} for the {@link Coalesce} function. + */ + public static final class Definition implements FunctionDefinition { + private Definition() {} + + @Override + public String getId() { + return ID; + } + + @Override + public List getArguments() { + return Collections.emptyList(); + } + + @Override + public Optional getVariadicArguments() { + return Optional.of(Type.anyType()); + } + + @Override + public Type getReturnType() { + return Type.anyType(); + } + + @Override + public Value evaluate(List arguments) { + throw new UnsupportedOperationException("Coalesce evaluation is handled by ExpressionVisitor"); + } + + @Override + public Coalesce createFunction(FunctionNode functionNode) { + return new Coalesce(functionNode); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java index e8878c85ed6..c07cad7b7fe 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java @@ -5,6 +5,7 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; import java.util.List; +import java.util.Optional; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -26,6 +27,18 @@ public interface FunctionDefinition { */ List getArguments(); + /** + * Gets the type of variadic arguments if this function accepts them. + * + *

When present, the function accepts any number of additional arguments of this type after the fixed arguments + * from getArguments(). + * + * @return the variadic argument type, or empty if not variadic + */ + default Optional getVariadicArguments() { + return Optional.empty(); + } + /** * The return type of this function definition. * @return The function return type diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java index c3b98048ccf..51a40498bae 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java @@ -109,7 +109,7 @@ public static FunctionNode fromNode(ObjectNode function) { * * @return this function as an expression. */ - public Expression createFunction() { + public LibraryFunction createFunction() { return EndpointRuleSet.createFunctionFactory() .apply(this) .orElseThrow(() -> new RuleError(new SourceException( diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index afddaec181f..c5daeab89d1 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -5,8 +5,11 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; +import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.SourceException; import software.amazon.smithy.model.SourceLocation; import software.amazon.smithy.model.node.Node; @@ -15,6 +18,10 @@ import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.utils.SmithyUnstableApi; import software.amazon.smithy.utils.StringUtils; @@ -41,6 +48,24 @@ public String getName() { return functionNode.getName(); } + @Override + protected Set calculateReferences() { + Set references = new LinkedHashSet<>(); + for (Expression arg : getArguments()) { + references.addAll(arg.getReferences()); + } + return references; + } + + /** + * Get the function definition. + * + * @return function definition. + */ + public FunctionDefinition getFunctionDefinition() { + return definition; + } + /** * @return The arguments to this function */ @@ -48,6 +73,17 @@ public List getArguments() { return functionNode.getArguments(); } + /** + * Returns a canonical form of this function. + * + *

Default implementation returns this. Override for functions that need canonicalization. + * + * @return the canonical form of this function + */ + public LibraryFunction canonicalize() { + return this; + } + protected Expression expectOneArgument() { List argv = functionNode.getArguments(); if (argv.size() == 1) { @@ -56,6 +92,11 @@ protected Expression expectOneArgument() { throw new RuleError(new SourceException("expected 1 argument but found " + argv.size(), functionNode)); } + @Override + public Condition.Builder toConditionBuilder() { + return Condition.builder().fn(this); + } + @Override public SourceLocation getSourceLocation() { return functionNode.getSourceLocation(); @@ -65,7 +106,7 @@ public SourceLocation getSourceLocation() { protected Type typeCheckLocal(Scope scope) { RuleError.context(String.format("while typechecking the invocation of %s", definition.getId()), this, () -> { try { - checkTypeSignature(definition.getArguments(), functionNode.getArguments(), scope); + checkTypeSignature(scope); } catch (InnerParseError e) { throw new RuntimeException(e.getMessage()); } @@ -73,37 +114,64 @@ protected Type typeCheckLocal(Scope scope) { return definition.getReturnType(); } - private void checkTypeSignature(List expectedArgs, List actualArguments, Scope scope) + private void checkTypeSignature(Scope scope) throws InnerParseError { + List expectedArgs = definition.getArguments(); + Optional variadicType = definition.getVariadicArguments(); + List actualArguments = functionNode.getArguments(); + + if (variadicType.isPresent()) { + // check we have at least the fixed arguments + if (actualArguments.size() < expectedArgs.size()) { + throw new InnerParseError(String.format("Expected at least %s arguments but found %s", + expectedArgs.size(), + actualArguments.size())); + } + // check fixed arguments + for (int i = 0; i < expectedArgs.size(); i++) { + checkArgument(i, expectedArgs.get(i), actualArguments.get(i), scope); + } + // check variadic arguments + Type varType = variadicType.get(); + for (int i = expectedArgs.size(); i < actualArguments.size(); i++) { + checkArgument(i, varType, actualArguments.get(i), scope); + } + } else { + // Non-variadic, so exact count required + if (expectedArgs.size() != actualArguments.size()) { + throw new InnerParseError(String.format("Expected %s arguments but found %s", + expectedArgs.size(), + actualArguments.size())); + } + // check all positional arguments + for (int i = 0; i < expectedArgs.size(); i++) { + checkArgument(i, expectedArgs.get(i), actualArguments.get(i), scope); + } + } + } + + private void checkArgument(int index, Type expected, Expression actual, Scope scope) throws InnerParseError { - if (expectedArgs.size() != actualArguments.size()) { - throw new InnerParseError( - String.format( - "Expected %s arguments but found %s", - expectedArgs.size(), - actualArguments)); + Type actualType = actual.typeCheck(scope); + if (expected.isA(actualType)) { + return; } - for (int i = 0; i < expectedArgs.size(); i++) { - Type expected = expectedArgs.get(i); - Type actual = actualArguments.get(i).typeCheck(scope); - if (!expected.isA(actual)) { - Type optAny = Type.optionalType(Type.anyType()); - String hint = ""; - if (actual.isA(optAny) && !expected.isA(optAny) - && actual.expectOptionalType().inner().equals(expected)) { - hint = String.format( - "hint: use `assign` in a condition or `isSet(%s)` to prove that this value is non-null", - actualArguments.get(i)); - hint = StringUtils.indent(hint, 2); - } - throw new InnerParseError( - String.format( - "Unexpected type in the %s argument: Expected %s but found %s%n%s", - ordinal(i + 1), - expected, - actual, - hint)); - } + + Type optAny = Type.optionalType(Type.anyType()); + String hint = ""; + if (actualType.isA(optAny) + && !expected.isA(optAny) + && actualType.expectOptionalType().inner().equals(expected)) { + hint = String.format( + "hint: use `assign` in a condition or `isSet(%s)` to prove that this value is non-null", + actual); + hint = StringUtils.indent(hint, 2); } + + throw new InnerParseError(String.format("Unexpected type in the %s argument: Expected %s but found %s%n%s", + ordinal(index + 1), + expected, + actualType, + hint)); } private static String ordinal(int arg) { @@ -153,4 +221,56 @@ public String toString() { } return getName() + "(" + String.join(", ", arguments) + ")"; } + + /** + * Determines if two arguments should be swapped for canonical ordering. + * Used by commutative functions to ensure consistent argument order. + * + * @param arg0 the first argument + * @param arg1 the second argument + * @return true if arguments should be swapped + */ + protected static boolean shouldSwapArgs(Expression arg0, Expression arg1) { + boolean arg0IsRef = isReference(arg0); + boolean arg1IsRef = isReference(arg1); + + // Always put References before literals to make things consistent + if (arg0IsRef != arg1IsRef) { + return !arg0IsRef; // Swap if arg0 is literal and arg1 is reference + } + + // Both same type, use string comparison for deterministic order + return arg0.toString().compareTo(arg1.toString()) > 0; + } + + /** + * Strips single-variable template wrappers if present. + * Converts "{varName}" to just varName reference. + * + * @param expr the expression to strip + * @return the stripped expression or original if not applicable + */ + static Expression stripSingleVariableTemplate(Expression expr) { + if (!(expr instanceof StringLiteral)) { + return expr; + } + + StringLiteral stringLit = (StringLiteral) expr; + List parts = stringLit.value().getParts(); + if (parts.size() == 1 && parts.get(0) instanceof Template.Dynamic) { + return ((Template.Dynamic) parts.get(0)).toExpression(); + } + + return expr; + } + + private static boolean isReference(Expression arg) { + if (arg instanceof Reference) { + return true; + } else if (arg instanceof StringLiteral) { + StringLiteral s = (StringLiteral) arg; + return !s.value().isStatic(); + } + return false; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java index 1c0f228d75d..8892e3ca32f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java @@ -56,6 +56,25 @@ public static StringEquals ofExpressions(ToExpression arg1, String arg2) { return ofExpressions(arg1, Expression.of(arg2)); } + @Override + public StringEquals canonicalize() { + List args = getArguments(); + + // Strip single-variable templates + Expression arg0 = stripSingleVariableTemplate(args.get(0)); + Expression arg1 = stripSingleVariableTemplate(args.get(1)); + + // Check if we need to reorder for commutative canonicalization + if (shouldSwapArgs(arg0, arg1)) { + return StringEquals.ofExpressions(arg1, arg0); + } else if (arg0 != args.get(0) || arg1 != args.get(1)) { + // Templates were stripped but no reordering needed + return StringEquals.ofExpressions(arg0, arg1); + } + + return this; + } + @Override public R accept(ExpressionVisitor visitor) { return visitor.visitStringEquals(functionNode.getArguments().get(0), functionNode.getArguments().get(1)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java index 8285b833799..70d5d41194c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java @@ -117,22 +117,23 @@ public Substring createFunction(FunctionNode functionNode) { * @return the substring value or null. */ public static String getSubstring(String value, int startIndex, int stopIndex, boolean reverse) { - if (startIndex >= stopIndex || value.length() < stopIndex) { + if (value == null) { return null; } - for (int i = 0; i < value.length(); i++) { - if (!(value.charAt(i) <= 127)) { + int len = value.length(); + if (startIndex < 0 || stopIndex > len || startIndex >= stopIndex) { + return null; + } + + int from = reverse ? len - stopIndex : startIndex; + int to = reverse ? len - startIndex : stopIndex; + for (int i = from; i < to; i++) { + if (value.charAt(i) > 127) { return null; } } - if (!reverse) { - return value.substring(startIndex, stopIndex); - } else { - int revStart = value.length() - stopIndex; - int revStop = value.length() - startIndex; - return value.substring(revStart, revStop); - } + return value.substring(from, to); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java index a3357a612d4..01454ef1724 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java @@ -6,9 +6,11 @@ import java.util.Collections; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.model.node.ObjectNode; @@ -72,4 +74,13 @@ public Node toNode() { members.forEach((k, v) -> builder.withMember(k.toString(), v.toNode())); return builder.build(); } + + @Override + protected Set calculateReferences() { + Set references = new LinkedHashSet<>(); + for (Literal value : members().values()) { + references.addAll(value.getReferences()); + } + return references; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java index 29dce7b61fe..9e9326f34af 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java @@ -4,8 +4,11 @@ */ package software.amazon.smithy.rulesengine.language.syntax.expressions.literal; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; @@ -66,4 +69,21 @@ public String toString() { public Node toNode() { return value.toNode(); } + + @Override + protected Set calculateReferences() { + Template template = value(); + if (template.isStatic()) { + return Collections.emptySet(); + } + + Set references = new LinkedHashSet<>(); + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + references.addAll(((Template.Dynamic) part).toExpression().getReferences()); + } + } + + return references; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java index 095b34bb182..71672ce554b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java @@ -5,9 +5,11 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions.literal; import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.ArrayNode; import software.amazon.smithy.model.node.Node; @@ -78,4 +80,13 @@ public Node toNode() { } return builder.build(); } + + @Override + protected Set calculateReferences() { + Set references = new LinkedHashSet<>(); + for (Literal member : members()) { + references.addAll(member.getReferences()); + } + return references; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java index eaff721790b..1a36efb4bcf 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java @@ -24,6 +24,7 @@ import software.amazon.smithy.utils.ListUtils; import software.amazon.smithy.utils.SmithyBuilder; import software.amazon.smithy.utils.SmithyUnstableApi; +import software.amazon.smithy.utils.StringUtils; import software.amazon.smithy.utils.ToSmithyBuilder; /** @@ -234,7 +235,7 @@ public Optional getDefault() { @Override public Condition.Builder toConditionBuilder() { - return Condition.builder().fn(toExpression()); + throw new UnsupportedOperationException("Cannot convert a Parameter to a Condition"); } @Override @@ -272,7 +273,7 @@ public Node toNode() { if (documentation != null) { node.withMember(DOCUMENTATION, documentation); } - node.withMember(TYPE, type.toString()); + node.withMember(TYPE, StringUtils.uncapitalize(type.toString())); return node.build(); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java index 95feeedbf54..352c1d8d447 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java @@ -4,6 +4,7 @@ */ package software.amazon.smithy.rulesengine.language.syntax.parameters; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -57,8 +58,7 @@ public static Builder builder() { public static Parameters fromNode(ObjectNode node) throws RuleError { Builder builder = new Builder(node); for (Map.Entry entry : node.getMembers().entrySet()) { - builder.addParameter(Parameter.fromNode(entry.getKey(), - RuleError.context("when parsing parameter", () -> entry.getValue().expectObjectNode()))); + builder.addParameter(Parameter.fromNode(entry.getKey(), entry.getValue().expectObjectNode())); } return builder.build(); } @@ -76,6 +76,15 @@ public void writeToScope(Scope scope) { } } + /** + * Convert the Parameters container to a list. + * + * @return the parameters list. + */ + public List toList() { + return Collections.unmodifiableList(parameters); + } + @Override public SourceLocation getSourceLocation() { return sourceLocation; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java index 4f331600381..d98b1ee21d5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java @@ -21,6 +21,7 @@ import software.amazon.smithy.rulesengine.language.syntax.SyntaxElement; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.utils.SmithyBuilder; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -31,8 +32,10 @@ @SmithyUnstableApi public final class Condition extends SyntaxElement implements TypeCheck, FromSourceLocation, ToNode { public static final String ASSIGN = "assign"; - private final Expression function; + private final LibraryFunction function; private final Identifier result; + private int hash; + private String toString; private Condition(Builder builder) { this.result = builder.result; @@ -86,10 +89,23 @@ public Optional getResult() { * * @return the function for this condition. */ - public Expression getFunction() { + public LibraryFunction getFunction() { return function; } + /** + * Returns a canonical form of this condition. + * + * @return the canonical condition + */ + public Condition canonicalize() { + LibraryFunction canonicalFn = function.canonicalize(); + if (canonicalFn == function) { + return this; + } + return toBuilder().fn(canonicalFn).build(); + } + @Override public Builder toConditionBuilder() { return toBuilder(); @@ -152,26 +168,32 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(function, result); + int h = hash; + if (h == 0) { + h = Objects.hash(function, result); + hash = h; + } + return h; } @Override public String toString() { - StringBuilder sb = new StringBuilder(); - if (result != null) { - sb.append(result).append(" = "); + String s = this.toString; + if (s == null) { + s = result != null ? (result + " = " + function) : function.toString(); + toString = s; } - return sb.append(function).toString(); + return s; } /** * A builder used to create a {@link Condition} class. */ public static class Builder implements SmithyBuilder { - private Expression fn; + private LibraryFunction fn; private Identifier result; - public Builder fn(Expression fn) { + public Builder fn(LibraryFunction fn) { this.fn = fn; return this; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java index f20c483fdf0..fddbe5f38b7 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java @@ -20,6 +20,7 @@ @SmithyUnstableApi public final class EndpointRule extends Rule { private final Endpoint endpoint; + private int hash; EndpointRule(Rule.Builder builder, Endpoint endpoint) { super(builder); @@ -46,7 +47,7 @@ protected Type typecheckValue(Scope scope) { } @Override - void withValueNode(ObjectNode.Builder builder) { + protected void withValueNode(ObjectNode.Builder builder) { builder.withMember("endpoint", endpoint).withMember(TYPE, ENDPOINT); } @@ -67,7 +68,12 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(super.hashCode(), endpoint); + int result = hash; + if (result == 0) { + result = Objects.hash(super.hashCode(), endpoint); + hash = result; + } + return hash; } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java index 03a60388f49..9dfce50faf6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java @@ -6,6 +6,7 @@ import static software.amazon.smithy.rulesengine.language.error.RuleError.context; +import java.util.Objects; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -19,6 +20,7 @@ @SmithyUnstableApi public final class ErrorRule extends Rule { private final Expression error; + private int hash; public ErrorRule(Rule.Builder builder, Expression error) { super(builder); @@ -45,7 +47,7 @@ protected Type typecheckValue(Scope scope) { } @Override - void withValueNode(ObjectNode.Builder builder) { + protected void withValueNode(ObjectNode.Builder builder) { builder.withMember("error", error.toNode()).withMember(TYPE, ERROR); } @@ -54,4 +56,28 @@ public String toString() { return super.toString() + StringUtils.indent(String.format("error(%s)", error), 2); } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } else if (!super.equals(object)) { + return false; + } else { + ErrorRule errorRule = (ErrorRule) object; + return Objects.equals(error, errorRule.error); + } + } + + @Override + public int hashCode() { + int result = hash; + if (result == 0) { + result = Objects.hash(super.hashCode(), error); + hash = result; + } + return hash; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java new file mode 100644 index 00000000000..d7c76f7feec --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.rule; + +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Sentinel rule for "no match" results. + */ +@SmithyUnstableApi +public final class NoMatchRule extends Rule { + + public static final NoMatchRule INSTANCE = new NoMatchRule(); + + private NoMatchRule() { + super(Rule.builder()); + } + + @Override + public T accept(RuleValueVisitor visitor) { + throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + } + + @Override + protected Type typecheckValue(Scope scope) { + throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + } + + @Override + protected void withValueNode(ObjectNode.Builder builder) { + // nothing + } + + @Override + public String toString() { + return "NO_MATCH"; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java index 5a66d029e09..33b1efa5090 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java @@ -84,7 +84,10 @@ public static Rule fromNode(Node node) { Builder builder = new Builder(node); objectNode.getStringMember(DOCUMENTATION, builder::description); - builder.conditions(objectNode.expectArrayMember(CONDITIONS).getElementsAs(Condition::fromNode)); + + objectNode.getArrayMember(CONDITIONS).ifPresent(conds -> { + builder.conditions(conds.getElementsAs(Condition::fromNode)); + }); String type = objectNode.expectStringMember(TYPE).getValue(); switch (type) { @@ -131,9 +134,31 @@ public Optional getDocumentation() { */ public abstract T accept(RuleValueVisitor visitor); + /** + * Get a new Rule of the same type that has the same values, but with the given conditions. + * + * @param conditions Conditions to use. + * @return the rule with the given conditions. + * @throws UnsupportedOperationException if it is a TreeRule or Condition rule. + */ + public final Rule withConditions(List conditions) { + if (getConditions().equals(conditions)) { + return this; + } else if (this instanceof ErrorRule) { + return new ErrorRule(ErrorRule.builder(this).conditions(conditions), ((ErrorRule) this).getError()); + } else if (this instanceof EndpointRule) { + return new EndpointRule(EndpointRule.builder(this).conditions(conditions), + ((EndpointRule) this).getEndpoint()); + } else if (this instanceof TreeRule) { + return new TreeRule(TreeRule.builder(this).conditions(conditions), ((TreeRule) this).getRules()); + } else { + throw new UnsupportedOperationException("Unknown rule type: " + this.getClass()); + } + } + protected abstract Type typecheckValue(Scope scope); - abstract void withValueNode(ObjectNode.Builder builder); + protected abstract void withValueNode(ObjectNode.Builder builder); @Override public Type typeCheck(Scope scope) { @@ -150,16 +175,17 @@ public Type typeCheck(Scope scope) { public Node toNode() { ObjectNode.Builder builder = ObjectNode.builder(); + if (documentation != null) { + builder.withMember(DOCUMENTATION, documentation); + } + + // TODO: we should remove the requirement of serializing an empty array here ArrayNode.Builder conditionsBuilder = ArrayNode.builder(); for (Condition condition : conditions) { conditionsBuilder.withValue(condition.toNode()); } builder.withMember(CONDITIONS, conditionsBuilder.build()); - if (documentation != null) { - builder.withMember(DOCUMENTATION, documentation); - } - withValueNode(builder); return builder.build(); } @@ -219,7 +245,7 @@ public Builder conditions(ToCondition... conditions) { return this; } - public Builder conditions(List conditions) { + public Builder conditions(List conditions) { this.conditions.addAll(conditions); return this; } @@ -229,20 +255,24 @@ public Builder condition(ToCondition condition) { return this; } - public Rule endpoint(Endpoint endpoint) { - return this.onBuild.apply(new EndpointRule(this, endpoint)); + public EndpointRule endpoint(Endpoint endpoint) { + return (EndpointRule) this.onBuild.apply(new EndpointRule(this, endpoint)); + } + + public ErrorRule error(Node error) { + return error(Expression.fromNode(error)); } - public Rule error(Node error) { - return this.onBuild.apply(new ErrorRule(this, Expression.fromNode(error))); + public ErrorRule error(String error) { + return error(Literal.of(error)); } - public Rule error(String error) { - return this.onBuild.apply(new ErrorRule(this, Literal.of(error))); + public ErrorRule error(Expression error) { + return (ErrorRule) this.onBuild.apply(new ErrorRule(this, error)); } - public Rule treeRule(Rule... rules) { - return this.treeRule(Arrays.asList(rules)); + public TreeRule treeRule(Rule... rules) { + return (TreeRule) this.treeRule(Arrays.asList(rules)); } @SafeVarargs diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java index d6a817cd3b4..2afe672cde6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java @@ -21,6 +21,7 @@ @SmithyUnstableApi public final class TreeRule extends Rule { private final List rules; + private int hash; TreeRule(Builder builder, List rules) { super(builder); @@ -54,7 +55,7 @@ protected Type typecheckValue(Scope scope) { } @Override - void withValueNode(ObjectNode.Builder builder) { + protected void withValueNode(ObjectNode.Builder builder) { ArrayNode.Builder rulesBuilder = ArrayNode.builder().sourceLocation(getSourceLocation()); for (Rule rule : rules) { rulesBuilder.withValue(rule.toNode()); @@ -70,4 +71,14 @@ public String toString() { } return super.toString() + StringUtils.indent(String.join("\n", ruleStrings), 2); } + + @Override + public int hashCode() { + int result = hash; + if (result == 0) { + result = super.hashCode(); + hash = result; + } + return result; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionEvaluator.java new file mode 100644 index 00000000000..dd4c1d06832 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionEvaluator.java @@ -0,0 +1,22 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +/** + * Evaluates a single condition using a condition index. + * + *

This functional interface provides maximum flexibility for condition evaluation implementations. Implementations + * are responsible maintaining their own internal state as methods are called (e.g., tracking variables). + */ +@FunctionalInterface +public interface ConditionEvaluator { + /** + * Evaluates the condition at the given index. + * + * @param conditionIndex the index of the condition to evaluate + * @return true if the condition is satisfied, false otherwise + */ + boolean test(int conditionIndex); +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java new file mode 100644 index 00000000000..c41c1d76e71 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java @@ -0,0 +1,69 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * A reference to a condition and whether it is negated. + */ +public final class ConditionReference { + + private final Condition condition; + private final boolean negated; + + public ConditionReference(Condition condition, boolean negated) { + this.condition = condition; + this.negated = negated; + } + + /** + * Returns true if this condition is negated (e.g., wrapped in not). + * + * @return true if negated. + */ + public boolean isNegated() { + return negated; + } + + /** + * Create a negated version of this reference. + * + * @return returns the negated reference. + */ + public ConditionReference negate() { + return new ConditionReference(condition, !negated); + } + + /** + * Get the underlying condition. + * + * @return condition. + */ + public Condition getCondition() { + return condition; + } + + @Override + public String toString() { + return (negated ? "!" : "") + condition.toString(); + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } + ConditionReference that = (ConditionReference) object; + return negated == that.negated && condition.equals(that.condition); + } + + @Override + public int hashCode() { + return condition.hashCode() ^ (negated ? 0x80000000 : 0); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluator.java new file mode 100644 index 00000000000..008de36eafb --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluator.java @@ -0,0 +1,27 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import software.amazon.smithy.rulesengine.language.evaluation.RuleEvaluator; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * Evaluates rules using a rules engine evaluator. + */ +public final class RuleBasedConditionEvaluator implements ConditionEvaluator { + private final RuleEvaluator evaluator; + private final Condition[] conditions; + + public RuleBasedConditionEvaluator(RuleEvaluator evaluator, Condition[] conditions) { + this.evaluator = evaluator; + this.conditions = conditions; + } + + @Override + public boolean test(int conditionIndex) { + Condition condition = conditions[conditionIndex]; + return evaluator.evaluateCondition(condition).isTruthy(); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java new file mode 100644 index 00000000000..89b209cb9c5 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java @@ -0,0 +1,318 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.UncheckedIOException; +import java.io.Writer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.function.Consumer; +import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; + +/** + * Binary Decision Diagram (BDD) with complement edges for efficient rule evaluation. + * + *

This class represents a pure BDD structure without any knowledge of the specific + * conditions or results it represents. The interpretation of condition indices and + * result indices is left to the caller. + * + *

Reference Encoding: + *

    + *
  • {@code 0}: Invalid/unused reference (never appears in valid BDDs)
  • + *
  • {@code 1}: TRUE terminal
  • + *
  • {@code -1}: FALSE terminal
  • + *
  • {@code 2, 3, ...}: Node references (points to nodes array at index ref-1)
  • + *
  • {@code -2, -3, ...}: Complement node references (logical NOT)
  • + *
  • {@code 100_000_000+}: Result terminals (100_000_000 + resultIndex)
  • + *
+ */ +public final class Bdd { + /** + * Result reference encoding offset. + */ + public static final int RESULT_OFFSET = 100_000_000; + + private final int[] nodes; // Flat array: [var0, high0, low0, var1, high1, low1, ...] + private final int rootRef; + private final int conditionCount; + private final int resultCount; + private final int nodeCount; + + /** + * Creates a BDD by streaming nodes directly into the structure. + * + * @param rootRef the root reference + * @param conditionCount the number of conditions + * @param resultCount the number of results + * @param nodeCount the exact number of nodes + * @param nodeHandler a handler that will provide nodes via a consumer + */ + public Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, Consumer nodeHandler) { + validateCounts(conditionCount, resultCount, nodeCount); + validateRootReference(rootRef, nodeCount); + + this.rootRef = rootRef; + this.conditionCount = conditionCount; + this.resultCount = resultCount; + this.nodeCount = nodeCount; + + InputNodeConsumer consumer = new InputNodeConsumer(nodeCount); + nodeHandler.accept(consumer); + this.nodes = consumer.nodes; + + if (consumer.index != nodeCount * 3) { + throw new IllegalStateException("Expected " + nodeCount + " nodes, but got " + (consumer.index / 3)); + } + } + + /** + * Creates a BDD by streaming nodes directly into the structure. + * + * @param rootRef the root reference + * @param conditionCount the number of conditions + * @param resultCount the number of results + * @param nodeCount the exact number of nodes + * @param nodes BDD nodes array where the condition, high, and low are all in succession. + */ + public Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, int[] nodes) { + validateCounts(conditionCount, resultCount, nodeCount); + validateRootReference(rootRef, nodeCount); + + if (nodes.length != nodeCount * 3) { + throw new IllegalArgumentException("Nodes array length must be nodeCount * 3"); + } + + this.rootRef = rootRef; + this.conditionCount = conditionCount; + this.resultCount = resultCount; + this.nodeCount = nodeCount; + this.nodes = nodes; + } + + private static void validateCounts(int conditionCount, int resultCount, int nodeCount) { + if (conditionCount < 0) { + throw new IllegalArgumentException("Condition count cannot be negative: " + conditionCount); + } else if (resultCount < 0) { + throw new IllegalArgumentException("Result count cannot be negative: " + resultCount); + } else if (nodeCount < 0) { + throw new IllegalArgumentException("Node count cannot be negative: " + nodeCount); + } + } + + private static void validateRootReference(int rootRef, int nodeCount) { + if (isComplemented(rootRef) && !isTerminal(rootRef)) { + throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef); + } else if (isNodeReference(rootRef)) { + int idx = Math.abs(rootRef) - 1; + if (idx >= nodeCount) { + throw new IllegalArgumentException("Root points to invalid BDD node: " + idx + + " (node count: " + nodeCount + ")"); + } + } + } + + private static final class InputNodeConsumer implements BddNodeConsumer { + private int index = 0; + private final int[] nodes; + + private InputNodeConsumer(int nodeCount) { + this.nodes = new int[nodeCount * 3]; + } + + @Override + public void accept(int var, int high, int low) { + nodes[index++] = var; + nodes[index++] = high; + nodes[index++] = low; + } + } + + /** + * Gets the number of conditions. + * + * @return condition count + */ + public int getConditionCount() { + return conditionCount; + } + + /** + * Gets the number of results. + * + * @return result count + */ + public int getResultCount() { + return resultCount; + } + + /** + * Gets the number of nodes in the BDD. + * + * @return the node count + */ + public int getNodeCount() { + return nodeCount; + } + + /** + * Gets the root node reference. + * + * @return root reference + */ + public int getRootRef() { + return rootRef; + } + + /** + * Gets the variable index for a node. + * + * @param nodeIndex the node index (0-based) + * @return the variable index + */ + public int getVariable(int nodeIndex) { + validateRange(nodeIndex); + return nodes[nodeIndex * 3]; + } + + private void validateRange(int index) { + if (index < 0 || index >= nodeCount) { + throw new IndexOutOfBoundsException("Node index out of bounds: " + index + " (size: " + nodeCount + ")"); + } + } + + /** + * Gets the high (true) reference for a node. + * + * @param nodeIndex the node index (0-based) + * @return the high reference + */ + public int getHigh(int nodeIndex) { + validateRange(nodeIndex); + return nodes[nodeIndex * 3 + 1]; + } + + /** + * Gets the low (false) reference for a node. + * + * @param nodeIndex the node index (0-based) + * @return the low reference + */ + public int getLow(int nodeIndex) { + validateRange(nodeIndex); + return nodes[nodeIndex * 3 + 2]; + } + + /** + * Write all nodes to the consumer. + * + * @param consumer the consumer to receive the integers + */ + public void getNodes(BddNodeConsumer consumer) { + for (int i = 0; i < nodeCount; i++) { + int base = i * 3; + consumer.accept(nodes[base], nodes[base + 1], nodes[base + 2]); + } + } + + /** + * Evaluates the BDD using the provided condition evaluator. + * + * @param ev the condition evaluator + * @return the result index, or -1 for no match + */ + public int evaluate(ConditionEvaluator ev) { + int ref = rootRef; + int[] n = this.nodes; + + while (isNodeReference(ref)) { + int idx = ref > 0 ? ref - 1 : -ref - 1; + int base = idx * 3; + // test ^ complement, pick hi or lo + ref = (ev.test(n[base]) ^ (ref < 0)) ? n[base + 1] : n[base + 2]; + } + + return isTerminal(ref) ? -1 : ref - RESULT_OFFSET; + } + + /** + * Checks if a reference points to a node (not a terminal or result). + * + * @param ref the reference to check + * @return true if this is a node reference + */ + public static boolean isNodeReference(int ref) { + return (ref > 1 && ref < RESULT_OFFSET) || (ref < -1 && ref > -RESULT_OFFSET); + } + + /** + * Checks if a reference points to a result. + * + * @param ref the reference to check + * @return true if this is a result reference + */ + public static boolean isResultReference(int ref) { + return ref >= RESULT_OFFSET; + } + + /** + * Checks if a reference is a terminal (TRUE or FALSE). + * + * @param ref the reference to check + * @return true if this is a terminal reference + */ + public static boolean isTerminal(int ref) { + return ref == 1 || ref == -1; + } + + /** + * Checks if a reference is complemented (negative). + * + * @param ref the reference to check + * @return true if the reference is complemented + */ + public static boolean isComplemented(int ref) { + return ref < 0 && ref != -1; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } else if (!(obj instanceof Bdd)) { + return false; + } + + Bdd other = (Bdd) obj; + if (rootRef != other.rootRef + || conditionCount != other.conditionCount + || resultCount != other.resultCount + || nodeCount != other.nodeCount) { + return false; + } + + return Arrays.equals(nodes, other.nodes); + } + + @Override + public int hashCode() { + return 31 * rootRef + nodeCount + Arrays.hashCode(nodes); + } + + @Override + public String toString() { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + Writer writer = new OutputStreamWriter(baos, StandardCharsets.UTF_8); + new BddFormatter(this, writer, "").format(); + writer.flush(); + return baos.toString(StandardCharsets.UTF_8.name()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java new file mode 100644 index 00000000000..a3d4398d728 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java @@ -0,0 +1,564 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.Arrays; + +/** + * Binary Decision Diagram (BDD) builder with complement edges and multi-terminal support. + * + *

This implementation uses CUDD-style complement edges where negative references + * represent logical negation. The engine supports both boolean operations and + * multi-terminal decision diagrams (MTBDDs) for endpoint resolution. + * + *

Reference encoding: + *

    + *
  • 0: Invalid/unused reference (never appears in valid BDDs)
  • + *
  • 1: TRUE terminal
  • + *
  • -1: FALSE terminal
  • + *
  • 2, 3, 4, ...: BDD nodes (use index + 1)
  • + *
  • -2, -3, -4, ...: Complement of BDD nodes
  • + *
  • Bdd.RESULT_OFFSET+: Result terminals (100_000_000 + resultIndex)
  • + *
+ */ +final class BddBuilder { + + // Terminal constants + private static final int TRUE_REF = 1; + private static final int FALSE_REF = -1; + + // Node storage: flat array [var0, high0, low0, var1, high1, low1, ...] + private static final int INITIAL_SIZE = 256 * 3; + private int[] nodes = new int[INITIAL_SIZE]; + private int nodeCount; + + // Unique tables for node deduplication and ITE caching + private final UniqueTable uniqueTable; + private final UniqueTable iteCache; + + // Track the boundary between conditions and results + private int conditionCount = -1; + + /** + * Creates a new BDD engine. + */ + public BddBuilder() { + this.nodeCount = 1; + this.uniqueTable = new UniqueTable(); + this.iteCache = new UniqueTable(1024); + initializeTerminalNode(); + } + + int getNodeCount() { + return nodeCount; + } + + /** + * Sets the number of conditions. Must be called before creating result nodes. + * + * @param count the number of conditions + */ + public void setConditionCount(int count) { + if (conditionCount != -1) { + throw new IllegalStateException("Condition count already set"); + } + this.conditionCount = count; + } + + /** + * Returns the TRUE terminal reference. + * + * @return TRUE reference (always 1) + */ + public int makeTrue() { + return TRUE_REF; + } + + /** + * Returns the FALSE terminal reference. + * + * @return FALSE reference (always -1) + */ + public int makeFalse() { + return FALSE_REF; + } + + /** + * Creates a result terminal reference. + * + * @param resultIndex the result index (must be non-negative) + * @return reference to the result terminal (RESULT_OFFSET + resultIndex) + * @throws IllegalArgumentException if resultIndex is negative + * @throws IllegalStateException if condition count not set + */ + public int makeResult(int resultIndex) { + if (conditionCount == -1) { + throw new IllegalStateException("Must set condition count before creating results"); + } else if (resultIndex < 0) { + throw new IllegalArgumentException("Result index must be non-negative: " + resultIndex); + } else { + return Bdd.RESULT_OFFSET + resultIndex; + } + } + + /** + * Creates or retrieves a BDD node for the given variable and branches. + * + *

Applies BDD reduction rules: + *

    + *
  • Eliminates redundant tests where both branches are identical
  • + *
  • Ensures complement edges appear only on the low branch
  • + *
  • Reuses existing nodes via the unique table
  • + *
+ * + * @param var the variable index + * @param high the reference for when variable is true + * @param low the reference for when variable is false + * @return reference to the BDD node + */ + public int makeNode(int var, int high, int low) { + if (conditionCount >= 0 && (var < 0 || var >= conditionCount)) { + throw new IllegalArgumentException("Variable out of bounds: " + var); + } else if (high == low) { + // Reduction rule: if both branches are identical, skip this test + return high; + } + + // Complement edge canonicalization: ensure complement only on low branch. + // Don't apply this to result nodes or when branches contain results + boolean flip = shouldFlip(high, low); + if (flip) { + high = negate(high); + low = negate(low); + } + + // Check if this node already exists + Integer existing = uniqueTable.get(var, high, low); + if (existing != null) { + return applyFlip(flip, existing); + } else { + return insertNode(var, high, low, flip); + } + } + + private boolean shouldFlip(int high, int low) { + return isComplement(low) && !isResult(high) && !isResult(low); + } + + private int applyFlip(boolean flip, int idx) { + return flip ? negate(toReference(idx)) : toReference(idx); + } + + private int insertNode(int var, int high, int low, boolean flip) { + ensureCapacity(); + + int idx = nodeCount; + int base = idx * 3; + nodes[base] = var; + nodes[base + 1] = high; + nodes[base + 2] = low; + nodeCount++; + + uniqueTable.put(var, high, low, idx); + return applyFlip(flip, idx); + } + + private void ensureCapacity() { + if (nodeCount * 3 >= nodes.length) { + // Double the current capacity + int newCapacity = nodes.length * 2; + nodes = Arrays.copyOf(nodes, newCapacity); + } + } + + /** + * Negates a BDD reference (logical NOT). + * + * @param ref the reference to negate + * @return the negated reference + * @throws IllegalArgumentException if ref is a result terminal or invalid. + */ + public int negate(int ref) { + if (ref == 0 || isResult(ref)) { + throw new IllegalArgumentException( + "Cannot negate " + (ref == 0 ? "invalid reference: " : "result terminal: ") + ref); + } + return -ref; + } + + /** + * Checks if a reference has a complement edge. + * + * @param ref the reference to check + * @return true if complemented (negative) + */ + public boolean isComplement(int ref) { + return ref < 0; + } + + /** + * Checks if a reference is a boolean terminal (TRUE or FALSE). + * + * @param ref the reference to check + * @return true if boolean terminal + */ + public boolean isTerminal(int ref) { + return Math.abs(ref) == 1; + } + + /** + * Checks if a reference is a result terminal. + * + * @param ref the reference to check + * @return true if result terminal + */ + public boolean isResult(int ref) { + if (isTerminal(ref) || conditionCount == -1) { + return false; + } + return ref >= Bdd.RESULT_OFFSET; + } + + /** + * Checks if a reference is a leaf (terminal or result). + * + * @param ref the reference to check + * @return true if this is a leaf node + */ + private boolean isLeaf(int ref) { + return Math.abs(ref) == TRUE_REF || ref >= Bdd.RESULT_OFFSET; + } + + /** + * Gets the variable index for a BDD node. + * + * @param ref the BDD reference + * @return the variable index, or -1 for terminals + */ + public int getVariable(int ref) { + if (isLeaf(ref)) { + return -1; + } + + int nodeIndex = Math.abs(ref) - 1; + validateNodeIndex(nodeIndex); + return nodes[nodeIndex * 3]; + } + + /** + * Computes the cofactor of a BDD with respect to a variable assignment. + * + * @param bdd the BDD to restrict + * @param varIndex the variable to fix + * @param value the value to assign (true or false) + * @return the restricted BDD + */ + public int cofactor(int bdd, int varIndex, boolean value) { + // Terminals and results are unaffected by cofactoring + if (isLeaf(bdd)) { + return bdd; + } + + boolean complemented = isComplement(bdd); + int nodeIndex = toNodeIndex(bdd); + validateNodeIndex(nodeIndex); + + int base = nodeIndex * 3; + int nodeVar = nodes[base]; + + if (nodeVar == varIndex) { + // This node tests our variable, so take the appropriate branch + int child = value ? nodes[base + 1] : nodes[base + 2]; + // Only negate if child is not a result + return (complemented && !isResult(child)) ? negate(child) : child; + } else if (nodeVar > varIndex) { + // Variable doesn't appear in this BDD (due to ordering) + return bdd; + } else { + // Variable appears deeper, so recurse on both branches + int high = cofactor(nodes[base + 1], varIndex, value); + int low = cofactor(nodes[base + 2], varIndex, value); + int result = makeNode(nodeVar, high, low); + return (complemented && !isResult(result)) ? negate(result) : result; + } + } + + /** + * Computes the logical AND of two BDDs. + * + * @param f first operand + * @param g second operand + * @return f AND g + * @throws IllegalArgumentException if operands are result terminals + */ + public int and(int f, int g) { + validateBooleanOperands(f, g, "AND"); + return ite(f, g, makeFalse()); + } + + /** + * Computes the logical OR of two BDDs. + * + * @param f first operand + * @param g second operand + * @return f OR g + * @throws IllegalArgumentException if operands are result terminals + */ + public int or(int f, int g) { + validateBooleanOperands(f, g, "OR"); + return ite(f, makeTrue(), g); + } + + /** + * Computes if-then-else (ITE) operation: "if f then g else h". + * + *

This is the fundamental BDD operation from which all others are derived. + * Includes optimizations for special cases and complement edges. + * + * @param f the condition (must be boolean) + * @param g the "then" branch + * @param h the "else" branch + * @return the resulting BDD + * @throws IllegalArgumentException if f is a result terminal + */ + public int ite(int f, int g, int h) { + // Normalize complement edge on f + if (f < 0) { + f = -f; + int tmp = g; + g = h; + h = tmp; + } + + // Quick terminal cases + if (f == TRUE_REF || g == h) { + return g; + } else if (isResult(f)) { + throw new IllegalArgumentException("Condition f must be boolean, not a result terminal"); + } else if (!isResult(g) && !isResult(h)) { + // Boolean-only identities + if (g == TRUE_REF && h == FALSE_REF) { + return f; + } else if (g == FALSE_REF && h == TRUE_REF) { + return negate(f); + } else if (g == f) { + return or(f, h); + } else if (h == f) { + return and(f, g); + } else if (g == negate(f)) { + return and(negate(f), h); + } else if (h == negate(f)) { + return or(negate(f), g); + } else if (isComplement(g) && isComplement(h)) { + // Factor out common complement + return negate(ite(f, negate(g), negate(h))); + } + } + + Integer cached = iteCache.get(f, g, h); + if (cached != null) { + return cached; + } + + // Reserve cache slot to handle recursive calls + iteCache.put(f, g, h, 0); // placeholder + + // Shannon expansion + int v = getTopVariable(f, g, h); + int r0 = ite(cofactor(f, v, false), cofactor(g, v, false), cofactor(h, v, false)); + int r1 = ite(cofactor(f, v, true), cofactor(g, v, true), cofactor(h, v, true)); + + // Build result node and cache it + int result = makeNode(v, r1, r0); + iteCache.put(f, g, h, result); + return result; + } + + /** + * Reduces the BDD by eliminating redundant nodes. + * + * @param rootRef the root of the BDD to reduce + * @return the reduced BDD root + */ + public int reduce(int rootRef) { + if (isLeaf(rootRef)) { + return rootRef; + } + + // Peel off complement on the root + boolean rootComp = rootRef < 0; + int absRoot = rootComp ? negate(rootRef) : rootRef; + + // Allocate new nodes array + int[] newNodes = new int[nodeCount * 3]; + + // Clear and reuse the existing unique table + uniqueTable.clear(); + + // Initialize the terminal node + newNodes[0] = -1; + newNodes[1] = TRUE_REF; + newNodes[2] = FALSE_REF; + + // Prepare the visitation map + int[] oldToNew = new int[nodeCount]; + Arrays.fill(oldToNew, -1); + int[] newCount = {1}; // start after terminal + + // Recursively rebuild + int newRoot = reduceRec(absRoot, oldToNew, newNodes, newCount); + + // Swap in the new nodes array (trimmed to actual size) + this.nodes = Arrays.copyOf(newNodes, newCount[0] * 3); + this.nodeCount = newCount[0]; + clearCaches(); + + return rootComp ? negate(newRoot) : newRoot; + } + + private int reduceRec( + int ref, + int[] oldToNew, + int[] newNodes, + int[] newCount + ) { + if (isLeaf(ref)) { + return ref; + } + + // Peel complement + boolean comp = ref < 0; + int abs = comp ? negate(ref) : ref; + int idx = toNodeIndex(abs); + + // If already mapped, return it + int mapped = oldToNew[idx]; + if (mapped != -1) { + return comp ? negate(mapped) : mapped; + } + + // Recurse on children + int base = idx * 3; + int var = nodes[base]; + int hiNew = reduceRec(nodes[base + 1], oldToNew, newNodes, newCount); + int loNew = reduceRec(nodes[base + 2], oldToNew, newNodes, newCount); + + // Reduction rule + int resultAbs; + if (hiNew == loNew) { + resultAbs = hiNew; + } else { + // Canonicalize complement edges on the low branch + boolean flip = shouldFlip(hiNew, loNew); + if (flip) { + hiNew = negate(hiNew); + loNew = negate(loNew); + } + + // Lookup or create a new node + Integer existing = uniqueTable.get(var, hiNew, loNew); + if (existing != null) { + resultAbs = toReference(existing); + } else { + int nodeIdx = newCount[0]++; + int newBase = nodeIdx * 3; + newNodes[newBase] = var; + newNodes[newBase + 1] = hiNew; + newNodes[newBase + 2] = loNew; + uniqueTable.put(var, hiNew, loNew, nodeIdx); + resultAbs = toReference(nodeIdx); + } + + if (flip) { + resultAbs = negate(resultAbs); + } + } + + oldToNew[idx] = resultAbs; + return comp ? negate(resultAbs) : resultAbs; + } + + /** + * Finds the topmost variable among three BDDs. + */ + private int getTopVariable(int f, int g, int h) { + int minVar = Integer.MAX_VALUE; + minVar = updateMinVariable(minVar, f); + minVar = updateMinVariable(minVar, g); + minVar = updateMinVariable(minVar, h); + return (minVar == Integer.MAX_VALUE) ? -1 : minVar; + } + + private int updateMinVariable(int currentMin, int ref) { + int absRef = Math.abs(ref); + if (absRef > 1 && absRef < Bdd.RESULT_OFFSET) { + return Math.min(currentMin, nodes[(absRef - 1) * 3]); + } + return currentMin; + } + + /** + * Clears all operation caches. + */ + public void clearCaches() { + iteCache.clear(); + } + + /** + * Clear out the state of the builder, but reuse the existing arrays, maps, etc. + * + * @return this builder + */ + public BddBuilder reset() { + clearCaches(); + uniqueTable.clear(); + Arrays.fill(nodes, 0, nodeCount * 3, 0); + nodeCount = 1; + initializeTerminalNode(); + conditionCount = -1; + return this; + } + + /** + * Builds a BDD from the current state of the builder. + * + * @return a new BDD instance + * @throws IllegalStateException if condition count has not been set + */ + Bdd build(int rootRef, int resultCount) { + if (conditionCount == -1) { + throw new IllegalStateException("Condition count must be set before building BDD"); + } + + int[] n = Arrays.copyOf(nodes, nodeCount * 3); + return new Bdd(rootRef, conditionCount, resultCount, nodeCount, n); + } + + private void validateBooleanOperands(int f, int g, String operation) { + if (isResult(f) || isResult(g)) { + throw new IllegalArgumentException("Cannot perform " + operation + " on result terminals"); + } + } + + private int toNodeIndex(int ref) { + return Math.abs(ref) - 1; + } + + private int toReference(int nodeIndex) { + return nodeIndex + 1; + } + + private void validateNodeIndex(int nodeIndex) { + if (nodeIndex >= nodeCount || nodeIndex < 0) { + throw new IllegalStateException("Invalid node index: " + nodeIndex); + } + } + + private void initializeTerminalNode() { + nodes[0] = -1; + nodes[1] = TRUE_REF; + nodes[2] = FALSE_REF; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java new file mode 100644 index 00000000000..f9aea0e4d1d --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java @@ -0,0 +1,166 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionReference; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; +import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * BDD compiler that builds a BDD from a CFG. + */ +@SmithyInternalApi +public final class BddCompiler { + private static final Logger LOGGER = Logger.getLogger(BddCompiler.class.getName()); + + private final Cfg cfg; + private final BddBuilder bddBuilder; + private final OrderingStrategy orderingStrategy; + + // Condition ordering + private List orderedConditions; + private Map conditionToIndex; + + // Result indexing + private final Map ruleToIndex = new HashMap<>(); + private final List indexedResults = new ArrayList<>(); + private int nextResultIndex = 0; + private int noMatchIndex = -1; + + // Simple cache to avoid recomputing identical subgraphs + private final Map nodeCache = new HashMap<>(); + + /** + * @param cfg CFG to convert to a BDD. + */ + public BddCompiler(Cfg cfg) { + this(cfg, new BddBuilder()); + } + + BddCompiler(Cfg cfg, BddBuilder bddBuilder) { + this(cfg, OrderingStrategy.initialOrdering(cfg), bddBuilder); + } + + BddCompiler(Cfg cfg, OrderingStrategy orderingStrategy, BddBuilder bddBuilder) { + this.cfg = Objects.requireNonNull(cfg, "CFG cannot be null"); + this.orderingStrategy = Objects.requireNonNull(orderingStrategy, "Ordering strategy cannot be null"); + this.bddBuilder = Objects.requireNonNull(bddBuilder, "BDD builder cannot be null"); + } + + /** + * Compile the CFG into a BDD. + * + * @return the compiled BDD. + */ + public Bdd compile() { + long start = System.currentTimeMillis(); + extractAndOrderConditions(); + + // Set the condition count in the builder + bddBuilder.setConditionCount(orderedConditions.size()); + + // Create the "no match" terminal + noMatchIndex = getOrCreateResultIndex(NoMatchRule.INSTANCE); + int rootRef = convertCfgToBdd(cfg.getRoot()); + rootRef = bddBuilder.reduce(rootRef); + Bdd bdd = bddBuilder.build(rootRef, indexedResults.size()); + + long elapsed = System.currentTimeMillis() - start; + LOGGER.fine(String.format( + "BDD compilation complete: %d conditions, %d results, %d BDD nodes in %dms", + orderedConditions.size(), + indexedResults.size(), + bdd.getNodeCount(), + elapsed)); + + return bdd; + } + + /** + * The ordered result rules after BDD compilation. + * + * @return ordered BDD result rules. + */ + public List getIndexedResults() { + return indexedResults; + } + + /** + * Get the ordered conditions referenced in the compiled BDD. + * + * @return the ordered BDD conditions. + */ + public List getOrderedConditions() { + return orderedConditions; + } + + private int convertCfgToBdd(CfgNode cfgNode) { + Integer cached = nodeCache.get(cfgNode); + if (cached != null) { + return cached; + } + + int result; + if (cfgNode == null) { + result = bddBuilder.makeResult(noMatchIndex); + } else if (cfgNode instanceof ResultNode) { + Rule rule = ((ResultNode) cfgNode).getResult(); + result = bddBuilder.makeResult(getOrCreateResultIndex(rule)); + } else { + ConditionNode cn = (ConditionNode) cfgNode; + ConditionReference ref = cn.getCondition(); + int varIdx = conditionToIndex.get(ref.getCondition()); + + // Recursively build the two branches + int hi = convertCfgToBdd(cn.getTrueBranch()); + int lo = convertCfgToBdd(cn.getFalseBranch()); + + // If the original rule said "not condition", swap branches + if (ref.isNegated()) { + int tmp = hi; + hi = lo; + lo = tmp; + } + + // Build the pure boolean test for variable varIdx + int test = bddBuilder.makeNode(varIdx, bddBuilder.makeTrue(), bddBuilder.makeFalse()); + + // Combine with ITE (reduces and merges) + result = bddBuilder.ite(test, hi, lo); + } + + nodeCache.put(cfgNode, result); + return result; + } + + private int getOrCreateResultIndex(Rule rule) { + return ruleToIndex.computeIfAbsent(rule, r -> { + int idx = nextResultIndex++; + indexedResults.add(r); + return idx; + }); + } + + private void extractAndOrderConditions() { + orderedConditions = orderingStrategy.orderConditions(cfg.getConditions()); + conditionToIndex = new LinkedHashMap<>(); + for (int i = 0; i < orderedConditions.size(); i++) { + conditionToIndex.put(orderedConditions.get(i), i); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java new file mode 100644 index 00000000000..eb60b7abba1 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java @@ -0,0 +1,374 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; +import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; + +/** + * Verifies functional equivalence between a CFG and its BDD representation. + * + *

This verifier uses structural equivalence checking to ensure that both representations produce the same result. + * When the BDD has fewer than 20 conditions, the checking is exhaustive. When there are more, random samples are + * checked up to the earlier of max samples being reached or the max duration being reached. + */ +public final class BddEquivalenceChecker { + + private static final Logger LOGGER = Logger.getLogger(BddEquivalenceChecker.class.getName()); + + private static final int EXHAUSTIVE_THRESHOLD = 20; + private static final int DEFAULT_MAX_SAMPLES = 1_000_000; + private static final Duration DEFAULT_TIMEOUT = Duration.ofMinutes(1); + + private final Cfg cfg; + private final Bdd bdd; + private final List conditions; + private final List results; + private final List parameters; + private final Map conditionToIndex = new HashMap<>(); + + private int maxSamples = DEFAULT_MAX_SAMPLES; + private Duration timeout = DEFAULT_TIMEOUT; + + private int testsRun = 0; + private long startTime; + + public static BddEquivalenceChecker of(Cfg cfg, Bdd bdd, List conditions, List results) { + return new BddEquivalenceChecker(cfg, bdd, conditions, results); + } + + private BddEquivalenceChecker(Cfg cfg, Bdd bdd, List conditions, List results) { + this.cfg = cfg; + this.bdd = bdd; + this.conditions = conditions; + this.results = results; + this.parameters = new ArrayList<>(cfg.getParameters().toList()); + + for (int i = 0; i < conditions.size(); i++) { + conditionToIndex.put(conditions.get(i), i); + } + } + + /** + * Sets the maximum number of samples to test for large condition sets. + * + *

Defaults to a max of 1M samples. Set to {@code <= 0} to disable the max. + * + * @param maxSamples the maximum number of samples + * @return this verifier for method chaining + */ + public BddEquivalenceChecker setMaxSamples(int maxSamples) { + if (maxSamples < 1) { + maxSamples = Integer.MAX_VALUE; + } + this.maxSamples = maxSamples; + return this; + } + + /** + * Sets the maximum amount of time to take for the verification (runs until timeout or max samples met). + * + *

Defaults to a 1-minute timeout if not overridden. + * + * @param timeout the timeout duration + * @return this verifier for method chaining + */ + public BddEquivalenceChecker setMaxDuration(Duration timeout) { + this.timeout = timeout; + return this; + } + + /** + * Verifies that the BDD produces identical results to the CFG. + * + * @throws VerificationException if any discrepancy is found + */ + public void verify() { + startTime = System.currentTimeMillis(); + verifyResults(); + testsRun = 0; + + LOGGER.info(() -> String.format("Verifying BDD with %d conditions (max samples: %d, timeout: %s)", + bdd.getConditionCount(), + maxSamples, + timeout)); + + if (bdd.getConditionCount() <= EXHAUSTIVE_THRESHOLD) { + verifyExhaustive(); + } else { + verifyWithLimits(); + } + + LOGGER.info(String.format("BDD verification passed: %d tests in %s", testsRun, getElapsedDuration())); + } + + private void verifyResults() { + Set cfgResults = new HashSet<>(); + for (CfgNode node : cfg) { + if (node instanceof ResultNode) { + Rule result = ((ResultNode) node).getResult(); + if (result != null) { + cfgResults.add(result); + } + } + } + + // Remove the NoMatchRule that's added by default. It's not in the CFG. + Set bddResults = new HashSet<>(results); + bddResults.removeIf(v -> v == NoMatchRule.INSTANCE); + + if (!cfgResults.equals(bddResults)) { + Set inCfgOnly = new HashSet<>(cfgResults); + inCfgOnly.removeAll(bddResults); + Set inBddOnly = new HashSet<>(bddResults); + inBddOnly.removeAll(cfgResults); + throw new IllegalStateException(String.format( + "Result mismatch: CFG has %d results, BDD has %d results (excluding NoMatchRule).%n" + + "In CFG only: %s%n" + + "In BDD only: %s", + cfgResults.size(), + bddResults.size(), + inCfgOnly, + inBddOnly)); + } + } + + /** + * Exhaustively tests all possible condition combinations. + */ + private void verifyExhaustive() { + long totalCombinations = 1L << bdd.getConditionCount(); + LOGGER.info(() -> "Running exhaustive verification with " + totalCombinations + " combinations"); + for (long mask = 0; mask < totalCombinations; mask++) { + verifyCase(mask); + if (hasEitherLimitBeenExceeded()) { + LOGGER.info(String.format("Exhaustive verification stopped after %d tests " + + "(limit: %d samples or %s timeout)", testsRun, maxSamples, timeout)); + break; + } + } + } + + /** + * Verifies with configured limits (samples and timeout). + * Continues until EITHER limit is reached: maxSamples reached OR timeout exceeded. + */ + private void verifyWithLimits() { + LOGGER.info(() -> String.format("Running limited verification (will stop at %d samples OR %s timeout)", + maxSamples, + timeout)); + verifyCriticalCases(); + + while (!hasEitherLimitBeenExceeded()) { + long mask = randomMask(); + verifyCase(mask); + if (testsRun % 10000 == 0 && testsRun > 0) { + LOGGER.fine(() -> String.format("Progress: %d tests run, %s elapsed", testsRun, getElapsedDuration())); + } + } + + LOGGER.info(() -> String.format("Verification complete: %d tests run in %s", testsRun, getElapsedDuration())); + } + + /** + * Tests critical edge cases that are likely to expose bugs. + */ + private void verifyCriticalCases() { + LOGGER.fine("Testing critical edge cases"); + + // All conditions false + verifyCase(0); + + // All conditions true + verifyCase((1L << bdd.getConditionCount()) - 1); + + // Each condition true individually + for (int i = 0; i < bdd.getConditionCount() && !hasEitherLimitBeenExceeded(); i++) { + verifyCase(1L << i); + } + + // Each condition false individually (all others true) + long allTrue = (1L << bdd.getConditionCount()) - 1; + for (int i = 0; i < bdd.getConditionCount() && !hasEitherLimitBeenExceeded(); i++) { + verifyCase(allTrue ^ (1L << i)); + } + + // Alternating patterns: 0101... (even conditions false, odd true) + if (!hasEitherLimitBeenExceeded()) { + verifyCase(0x5555555555555555L & ((1L << bdd.getConditionCount()) - 1)); + } + + // Pattern: 1010... (even conditions true, odd false) + if (!hasEitherLimitBeenExceeded()) { + verifyCase(0xAAAAAAAAAAAAAAAAL & ((1L << bdd.getConditionCount()) - 1)); + } + } + + private boolean hasEitherLimitBeenExceeded() { + return testsRun >= maxSamples || isTimedOut(); + } + + private boolean isTimedOut() { + return getElapsedDuration().compareTo(timeout) >= 0; + } + + private Duration getElapsedDuration() { + return Duration.ofMillis(System.currentTimeMillis() - startTime); + } + + private void verifyCase(long mask) { + testsRun++; + + // Create evaluators that will return fixed values for conditions + FixedMaskEvaluator maskEvaluator = new FixedMaskEvaluator(mask); + Rule cfgResult = evaluateCfgWithMask(maskEvaluator); + Rule bddResult = evaluateBdd(mask); + + if (!resultsEqual(cfgResult, bddResult)) { + StringBuilder errorMsg = new StringBuilder(); + errorMsg.append("BDD verification mismatch found!\n"); + errorMsg.append("Test case #").append(testsRun).append("\n"); + errorMsg.append("Condition mask: ").append(Long.toBinaryString(mask)).append("\n"); + errorMsg.append("\nCondition details:\n"); + for (int i = 0; i < conditions.size(); i++) { + Condition condition = conditions.get(i); + boolean value = (mask & (1L << i)) != 0; + errorMsg.append(" Condition ") + .append(i) + .append(" [") + .append(value) + .append("]: ") + .append(condition) + .append("\n"); + } + errorMsg.append("\nResults:\n"); + errorMsg.append(" CFG result: ").append(describeResult(cfgResult)).append("\n"); + errorMsg.append(" BDD result: ").append(describeResult(bddResult)); + throw new VerificationException(errorMsg.toString()); + } + } + + private Rule evaluateCfgWithMask(ConditionEvaluator maskEvaluator) { + Map cfgConditionToIndex = new HashMap<>(); + Condition[] cfgConditions = cfg.getConditions(); + for (int i = 0; i < cfgConditions.length; i++) { + cfgConditionToIndex.put(cfgConditions[i], i); + } + + CfgNode result = evaluateCfgNode(cfg.getRoot(), cfgConditionToIndex, maskEvaluator); + if (result instanceof ResultNode) { + return ((ResultNode) result).getResult(); + } + + return null; + } + + // Recursively evaluates a CFG node. + private CfgNode evaluateCfgNode( + CfgNode node, + Map conditionToIndex, + ConditionEvaluator maskEvaluator + ) { + if (node instanceof ResultNode) { + return node; + } + + if (node instanceof ConditionNode) { + ConditionNode condNode = (ConditionNode) node; + Condition condition = condNode.getCondition().getCondition(); + + Integer index = conditionToIndex.get(condition); + if (index == null) { + throw new IllegalStateException("Condition not found in CFG: " + condition); + } + + boolean conditionResult = maskEvaluator.test(index); + + // Handle negation if the condition reference is negated + if (condNode.getCondition().isNegated()) { + conditionResult = !conditionResult; + } + + // Follow the appropriate branch + if (conditionResult) { + return evaluateCfgNode(condNode.getTrueBranch(), conditionToIndex, maskEvaluator); + } else { + return evaluateCfgNode(condNode.getFalseBranch(), conditionToIndex, maskEvaluator); + } + } + + throw new IllegalStateException("Unknown CFG node type: " + node); + } + + private Rule evaluateBdd(long mask) { + FixedMaskEvaluator evaluator = new FixedMaskEvaluator(mask); + int resultIndex = bdd.evaluate(evaluator); + return resultIndex < 0 ? null : results.get(resultIndex); + } + + private boolean resultsEqual(Rule r1, Rule r2) { + if (r1 == r2) { + return true; + } else if (r1 == null || r2 == null) { + return false; + } else { + return r1.withConditions(Collections.emptyList()).equals(r2.withConditions(Collections.emptyList())); + } + } + + // Generates a random bit mask for sampling. + private long randomMask() { + long mask = 0; + for (int i = 0; i < bdd.getConditionCount(); i++) { + if (Math.random() < 0.5) { + mask |= (1L << i); + } + } + return mask; + } + + private String describeResult(Rule rule) { + return rule == null ? "null (no match)" : rule.toString(); + } + + // A condition evaluator that returns values based on a fixed bit mask. + private static class FixedMaskEvaluator implements ConditionEvaluator { + private final long mask; + + FixedMaskEvaluator(long mask) { + this.mask = mask; + } + + @Override + public boolean test(int conditionIndex) { + return (mask & (1L << conditionIndex)) != 0; + } + } + + /** + * Exception thrown when verification fails. + */ + public static class VerificationException extends RuntimeException { + public VerificationException(String message) { + super(message); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java new file mode 100644 index 00000000000..124be11f74d --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java @@ -0,0 +1,199 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.nio.charset.StandardCharsets; + +/** + * Formats BDD node structures to a writer. + */ +public final class BddFormatter { + + private final Bdd bdd; + private final Writer writer; + private final String indent; + + /** + * Creates a BDD formatter. + * + * @param bdd the BDD to format + * @param writer the writer to format to + * @param indent the indentation string + */ + public BddFormatter(Bdd bdd, Writer writer, String indent) { + this.bdd = bdd; + this.writer = writer; + this.indent = indent; + } + + /** + * Formats a BDD to a string. + * + * @param bdd the BDD to format + * @return a formatted string representation + */ + public static String format(Bdd bdd) { + return format(bdd, ""); + } + + /** + * Formats a BDD to a string with custom indent. + * + * @param bdd the BDD to format + * @param indent the indentation string + * @return a formatted string representation + */ + public static String format(Bdd bdd, String indent) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + Writer writer = new OutputStreamWriter(baos, StandardCharsets.UTF_8); + new BddFormatter(bdd, writer, indent).format(); + writer.flush(); + return baos.toString(StandardCharsets.UTF_8.name()); + } catch (IOException e) { + // Should never happen with ByteArrayOutputStream + throw new RuntimeException("Failed to format BDD", e); + } + } + + /** + * Formats the BDD structure. + * + * @throws IOException if writing fails + */ + public void format() throws IOException { + // Calculate formatting widths + FormatContext ctx = calculateFormatContext(); + + // Write header + writer.write(indent); + writer.write("Bdd {\n"); + + // Write counts + writer.write(indent); + writer.write(" conditions: "); + writer.write(String.valueOf(bdd.getConditionCount())); + writer.write("\n"); + + writer.write(indent); + writer.write(" results: "); + writer.write(String.valueOf(bdd.getResultCount())); + writer.write("\n"); + + // Write root + writer.write(indent); + writer.write(" root: "); + writer.write(formatReference(bdd.getRootRef())); + writer.write("\n"); + + // Write nodes + writer.write(indent); + writer.write(" nodes ("); + writer.write(String.valueOf(bdd.getNodeCount())); + writer.write("):\n"); + + for (int i = 0; i < bdd.getNodeCount(); i++) { + writer.write(indent); + writer.write(" "); + writer.write(String.format("%" + ctx.indexWidth + "d", i)); + writer.write(": "); + + if (i == 0 && bdd.getVariable(0) == -1) { + writer.write("terminal"); + } else { + formatNode(i, ctx); + } + writer.write("\n"); + } + + writer.write(indent); + writer.write("}"); + } + + private FormatContext calculateFormatContext() { + int nodeCount = bdd.getNodeCount(); + int maxVarIdx = -1; + + // Scan nodes to find max variable index + for (int i = 1; i < nodeCount; i++) { + int varIdx = bdd.getVariable(i); + if (varIdx >= 0) { + maxVarIdx = Math.max(maxVarIdx, varIdx); + } + } + + // Calculate widths + int conditionCount = bdd.getConditionCount(); + int resultCount = bdd.getResultCount(); + int conditionWidth = conditionCount > 0 ? String.valueOf(conditionCount - 1).length() + 1 : 2; + int resultWidth = resultCount > 0 ? String.valueOf(resultCount - 1).length() + 1 : 2; + int varWidth = Math.max(Math.max(conditionWidth, resultWidth), String.valueOf(maxVarIdx).length()); + int indexWidth = String.valueOf(nodeCount - 1).length(); + + return new FormatContext(varWidth, indexWidth); + } + + private void formatNode(int nodeIndex, FormatContext ctx) throws IOException { + writer.write("["); + + // Variable reference + int varIdx = bdd.getVariable(nodeIndex); + String varRef = formatVariableIndex(varIdx); + writer.write(String.format("%" + ctx.varWidth + "s", varRef)); + + // High and low references + writer.write(", "); + writer.write(String.format("%6s", formatReference(bdd.getHigh(nodeIndex)))); + writer.write(", "); + writer.write(String.format("%6s", formatReference(bdd.getLow(nodeIndex)))); + writer.write("]"); + } + + private String formatVariableIndex(int varIdx) { + if (bdd.getConditionCount() > 0 && varIdx < bdd.getConditionCount()) { + return "C" + varIdx; + } else if (bdd.getConditionCount() > 0 && bdd.getResultCount() > 0) { + return "R" + (varIdx - bdd.getConditionCount()); + } else { + return String.valueOf(varIdx); + } + } + + /** + * Formats a BDD reference (node pointer) to a human-readable string. + * + * @param ref the reference to format + * @return the formatted reference string + */ + public static String formatReference(int ref) { + if (ref == 0) { + return "INVALID"; + } else if (ref == 1) { + return "TRUE"; + } else if (ref == -1) { + return "FALSE"; + } else if (ref >= Bdd.RESULT_OFFSET) { + return "R" + (ref - Bdd.RESULT_OFFSET); + } else if (ref < 0) { + return "!" + (-ref - 1); + } else { + return String.valueOf(ref - 1); + } + } + + private static class FormatContext { + final int varWidth; + final int indexWidth; + + FormatContext(int varWidth, int indexWidth) { + this.varWidth = varWidth; + this.indexWidth = indexWidth; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeConsumer.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeConsumer.java new file mode 100644 index 00000000000..4c34cea386a --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeConsumer.java @@ -0,0 +1,19 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +/** + * Consumer that receives every node in a {@link Bdd}. + */ +public interface BddNodeConsumer { + /** + * Receives a BDD node. + * + * @param var Variable. + * @param high High reference. + * @param low Low reference. + */ + void accept(int var, int high, int low); +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysis.java new file mode 100644 index 00000000000..f7645ac2b3d --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysis.java @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; +import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; + +/** + * Analyzes a CFG to compute cone information for each condition. + * + *

A "cone" is the subgraph of the CFG that is reachable from a given condition node. Think of it as the + * "downstream impact" of a condition. A condition with a large cone controls many downstream decisions (high impact). + * A condition with many reachable result nodes in its cone affects many endpoints. Conditions that appear early in the + * CFG (low dominator depth) are "gates" that control access to large portions of the decision tree. + */ +final class CfgConeAnalysis { + /** + * Dominator depth for each condition, or how many edges from CFG root to first occurrence. + * Initialized to MAX_VALUE, then updated to minimum depth encountered during traversal. + */ + private final int[] dominatorDepth; + + /** Number of result nodes (endpoints/errors) reachable from each condition's cone. */ + private final int[] reachableResults; + + /** Cache of computed cone information for each CFG node to avoid redundant traversals. */ + private final Map coneCache = new HashMap<>(); + + /** Maps conditions to their indices for quick lookups. */ + private final Map conditionToIndex; + + /** + * Creates a new cone analysis for the given CFG and conditions. + * + * @param cfg the control flow graph to analyze + * @param conditions array of conditions in the rule set + * @param conditionToIndex mapping from conditions to their indices + */ + public CfgConeAnalysis(Cfg cfg, Condition[] conditions, Map conditionToIndex) { + this.conditionToIndex = conditionToIndex; + int n = conditions.length; + this.dominatorDepth = new int[n]; + this.reachableResults = new int[n]; + Arrays.fill(dominatorDepth, Integer.MAX_VALUE); + analyzeCfgNode(cfg.getRoot(), 0); + } + + /** + * Recursively analyzes a CFG node and its subtree, computing cone information. + * + * @param node the current CFG node being analyzed + * @param depth the current depth in the CFG traversal (edges from root) + * @return cone information for this subtree + */ + private ConeInfo analyzeCfgNode(CfgNode node, int depth) { + if (node == null) { + return ConeInfo.empty(); + } + + ConeInfo cached = coneCache.get(node); + if (cached != null) { + if (cached.inProgress) { + throw new IllegalStateException("Cycle detected in CFG during cone analysis: " + node); + } + return cached; + } + + // Cycle guard: if a transform accidentally introduced a cycle, fail fast. + coneCache.put(node, ConeInfo.IN_PROGRESS); + ConeInfo info; + + if (node instanceof ResultNode) { + info = ConeInfo.singleResult(); + } else if (node instanceof ConditionNode) { + ConditionNode condNode = (ConditionNode) node; + Condition condition = condNode.getCondition().getCondition(); + Integer conditionIdx = conditionToIndex.get(condition); + if (conditionIdx == null) { + throw new IllegalStateException("Condition not indexed in CFG: " + condition); + } + + // Handle conditions that appear multiple times by updating dominator depth. + // Keep the minimum depth where this condition appears. + dominatorDepth[conditionIdx] = Math.min(dominatorDepth[conditionIdx], depth); + + ConeInfo trueBranchCone = analyzeCfgNode(condNode.getTrueBranch(), depth + 1); + ConeInfo falseBranchCone = analyzeCfgNode(condNode.getFalseBranch(), depth + 1); + info = ConeInfo.combine(trueBranchCone, falseBranchCone); + + // Update the maximum result count this condition can influence + reachableResults[conditionIdx] = Math.max(reachableResults[conditionIdx], info.resultNodes); + } else { + throw new UnsupportedOperationException("Unknown node type: " + node); + } + + coneCache.put(node, info); + return info; + } + + /** + * Gets the dominator depth of a condition, the minimum depth at which this condition appears in the CFG. + * + *

Lower values indicate conditions that appear earlier in the decision tree and have more influence over + * the overall control flow. + * + * @param conditionIdx the index of the condition + * @return the minimum depth, or Integer.MAX_VALUE if never encountered + */ + public int dominatorDepth(int conditionIdx) { + return dominatorDepth[conditionIdx]; + } + + /** + * Gets the cone size as the number of reachable result nodes for a condition, representing how many different + * endpoints/errors can be reached downstream of the condition. + * + *

Larger values indicate conditions that have broader impact on the final outcome. + * + * @param conditionIdx the index of the condition + * @return the number of result nodes in this condition's cone + */ + public int coneSize(int conditionIdx) { + return reachableResults[conditionIdx]; + } + + private static final class ConeInfo { + private static final ConeInfo IN_PROGRESS = new ConeInfo(0, true); + + final int resultNodes; + final boolean inProgress; + + private ConeInfo(int resultNodes, boolean inProgress) { + this.resultNodes = resultNodes; + this.inProgress = inProgress; + } + + private static ConeInfo empty() { + return new ConeInfo(0, false); + } + + private static ConeInfo singleResult() { + return new ConeInfo(1, false); + } + + private static ConeInfo combine(ConeInfo trueBranch, ConeInfo falseBranch) { + if (trueBranch.inProgress || falseBranch.inProgress) { + throw new IllegalStateException("Cycle detected in CFG during cone analysis (branch in-progress)"); + } + return new ConeInfo(trueBranch.resultNodes + falseBranch.resultNodes, false); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java new file mode 100644 index 00000000000..7ed2dba09df --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.BitSet; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; + +/** + * Orders conditions by following the natural structure of the CFG. + * + *

This strategy has proven to be the most effective for BDD construction because it preserves the locality that + * rule authors built into their decision trees. Conditions that are evaluated together in the original rules stay + * together in the BDD, enabling better node sharing. This ordering implementation flattens the tree structure while + * respecting data dependencies. + */ +final class InitialOrdering implements OrderingStrategy { + private static final Logger LOGGER = Logger.getLogger(InitialOrdering.class.getName()); + + /** How many distinct consumers make an isSet() a "gate". */ + private static final int GATE_SUCCESSOR_THRESHOLD = 2; + + private final Cfg cfg; + + InitialOrdering(Cfg cfg) { + this.cfg = cfg; + } + + @Override + public List orderConditions(Condition[] conditions) { + long startTime = System.currentTimeMillis(); + + ConditionDependencyGraph deps = new ConditionDependencyGraph(Arrays.asList(conditions)); + Map conditionToIndex = deps.getConditionToIndex(); + CfgConeAnalysis cones = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + List order = buildCfgOrder(conditions, deps, cones); + + List result = new ArrayList<>(); + for (int id : order) { + result.add(conditions[id]); + } + + long elapsed = System.currentTimeMillis() - startTime; + LOGGER.info(() -> String.format("Initial ordering: %d conditions in %dms", conditions.length, elapsed)); + return result; + } + + // Builds an ordering using a topological sort that prefers conditions based on their position in the CFG. + private List buildCfgOrder(Condition[] conditions, ConditionDependencyGraph deps, CfgConeAnalysis cones) { + List result = new ArrayList<>(); + BitSet placed = new BitSet(); + BitSet ready = new BitSet(); + + // Start with conditions that have no dependencies + for (int i = 0; i < conditions.length; i++) { + if (deps.getPredecessorCount(i) == 0) { + ready.set(i); + } + } + + while (!ready.isEmpty()) { + int chosen = getNext(ready, cones, conditions, deps); + result.add(chosen); + placed.set(chosen); + ready.clear(chosen); + + // Make successors ready if all their dependencies are satisfied + for (int succ : deps.getSuccessors(chosen)) { + if (!placed.get(succ) && allPredecessorsPlaced(succ, deps, placed)) { + ready.set(succ); + } + } + } + + if (result.size() != conditions.length) { + throw new IllegalStateException("Topological ordering incomplete (possible cyclic deps). Placed=" + + result.size() + " of " + conditions.length); + } + + return result; + } + + /** + * Selects the next condition to place based on CFG structure. + * + *

1. Pick conditions closest to the CFG root (minimum depth) + * 2. Prefer "gate" conditions that guard many branches + * 3. Break ties with cone size (bigger is more discriminating) + * 4. Tie-break with ID for determinism + */ + private int getNext(BitSet ready, CfgConeAnalysis cones, Condition[] conditions, ConditionDependencyGraph deps) { + int best = -1; + int bestDepth = Integer.MAX_VALUE; + int bestCone = -1; + boolean bestIsGate = false; + + for (int i = ready.nextSetBit(0); i >= 0; i = ready.nextSetBit(i + 1)) { + int depth = cones.dominatorDepth(i); + if (depth > bestDepth) { + continue; // Skip if worse + } + + int cone = cones.coneSize(i); + boolean isGate = isIsSet(conditions[i]) && deps.getSuccessorCount(i) > GATE_SUCCESSOR_THRESHOLD; + + if (depth < bestDepth) { + // New best if shallower + best = i; + bestDepth = depth; + bestCone = cone; + bestIsGate = isGate; + } else if (!bestIsGate && isGate) { + // Gates win + best = i; + bestCone = cone; + bestIsGate = true; + } else if (bestIsGate == isGate && (cone > bestCone || (cone == bestCone && i < best))) { + // Same gate status, so pick larger cone or lower ID for stability + best = i; + bestCone = cone; + } + } + + return best; + } + + private boolean allPredecessorsPlaced(int id, ConditionDependencyGraph deps, BitSet placed) { + for (int pred : deps.getPredecessors(id)) { + if (!placed.get(pred)) { + return false; + } + } + return true; + } + + private static boolean isIsSet(Condition c) { + return c.getFunction().getFunctionDefinition() == IsSet.getDefinition(); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java new file mode 100644 index 00000000000..a19f6b36c98 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java @@ -0,0 +1,101 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.function.Function; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; + +/** + * Reverses the node ordering in a BDD from bottom-up to top-down for better cache locality. + * + *

This transformation reverses the node array (except the terminal at index 0) + * and updates all references throughout the BDD to maintain correctness. + */ +public final class NodeReversal implements Function { + + private static final Logger LOGGER = Logger.getLogger(NodeReversal.class.getName()); + + @Override + public EndpointBddTrait apply(EndpointBddTrait trait) { + Bdd reversedBdd = reverse(trait.getBdd()); + // Only rebuild the trait if the BDD actually changed + return reversedBdd == trait.getBdd() ? trait : trait.toBuilder().bdd(reversedBdd).build(); + } + + /** + * Reverses the node ordering in a BDD. + * + * @param bdd the BDD to reverse + * @return the reversed BDD, or the original if too small to reverse + */ + public static Bdd reverse(Bdd bdd) { + LOGGER.info("Starting BDD node reversal optimization"); + int nodeCount = bdd.getNodeCount(); + + if (nodeCount <= 2) { + return bdd; + } + + // Create the index mapping: old index -> new index + // Index 0 (terminal) stays at 0 + int[] oldToNew = new int[nodeCount]; + oldToNew[0] = 0; + + // Reverse indices for non-terminal nodes + for (int oldIdx = 1; oldIdx < nodeCount; oldIdx++) { + int newIdx = nodeCount - oldIdx; + oldToNew[oldIdx] = newIdx; + } + + // Remap the root reference + int newRoot = remapReference(bdd.getRootRef(), oldToNew); + + // Create reversed BDD using streaming constructor + return new Bdd(newRoot, bdd.getConditionCount(), bdd.getResultCount(), nodeCount, consumer -> { + // Terminal stays at index 0 + consumer.accept(bdd.getVariable(0), bdd.getHigh(0), bdd.getLow(0)); + + // Add nodes in reverse order, updating their references + for (int oldIdx = nodeCount - 1; oldIdx >= 1; oldIdx--) { + int var = bdd.getVariable(oldIdx); + int high = remapReference(bdd.getHigh(oldIdx), oldToNew); + int low = remapReference(bdd.getLow(oldIdx), oldToNew); + consumer.accept(var, high, low); + } + }); + } + + /** + * Remaps a reference through the index mapping. + * + * @param ref the reference to remap + * @param oldToNew the index mapping array + * @return the remapped reference + */ + private static int remapReference(int ref, int[] oldToNew) { + // Return result references as-is. + if (ref == 0) { + return 0; + } else if (ref == 1 || ref == -1) { + return ref; + } else if (ref >= Bdd.RESULT_OFFSET) { + return ref; + } + + // Handle regular node references (with possible complement) + boolean isComplemented = ref < 0; + int absRef = isComplemented ? -ref : ref; + int oldIdx = absRef - 1; // convert 1-based to 0-based + + if (oldIdx >= oldToNew.length) { + throw new IllegalStateException("Invalid reference: " + ref); + } + + int newIdx = oldToNew[oldIdx]; + int newRef = newIdx + 1; + return isComplemented ? -newRef : newRef; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java new file mode 100644 index 00000000000..3378c319427 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java @@ -0,0 +1,40 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.List; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; + +/** + * Strategy interface for ordering conditions in a BDD. + */ +@FunctionalInterface +interface OrderingStrategy { + /** + * Orders the given conditions for BDD construction. + * + * @param conditions array of conditions to order + * @return ordered list of conditions + */ + List orderConditions(Condition[] conditions); + + /** + * Creates an initial ordering strategy using the given CFG. + * + * @param cfg CFG to process. + * @return the initial ordering strategy. + */ + static OrderingStrategy initialOrdering(Cfg cfg) { + return new InitialOrdering(cfg); + } + + /** + * Fixed ordering strategy that uses a pre-determined order. + */ + static OrderingStrategy fixed(List ordering) { + return conditions -> ordering; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java new file mode 100644 index 00000000000..a40f8484702 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -0,0 +1,593 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; +import software.amazon.smithy.utils.SmithyBuilder; + +/** + * BDD optimization using tiered parallel position evaluation with dependency-aware constraints. + * + *

The optimization runs in three stages with decreasing granularity: + *

    + *
  • Coarse: Fast reduction with large steps
  • + *
  • Medium: Balanced optimization
  • + *
  • Granular: Fine-tuned optimization for maximum reduction
  • + *
+ */ +public final class SiftingOptimization implements Function { + private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); + + // When to use a parallel stream + private static final int PARALLEL_THRESHOLD = 7; + + // Thread-local BDD builders to avoid allocation overhead + private final ThreadLocal threadBuilder = ThreadLocal.withInitial(BddBuilder::new); + + private final Cfg cfg; + private final ConditionDependencyGraph dependencyGraph; + + // Tiered optimization settings + private final int coarseMinNodes; + private final int coarseMaxPasses; + private final int mediumMinNodes; + private final int mediumMaxPasses; + private final int granularMaxNodes; + private final int granularMaxPasses; + + // Internal effort levels for the tiered optimization stages. + private enum OptimizationEffort { + COARSE(11, 4, 0, 20, 4_000, 6), + MEDIUM(2, 20, 6, 20, 1_000, 6), + GRANULAR(1, 50, 12, 20, 8_000, 12); + + final int sampleRate; + final int maxPositions; + final int nearbyRadius; + final int exhaustiveThreshold; + final int defaultNodeThreshold; + final int defaultMaxPasses; + + OptimizationEffort( + int sampleRate, + int maxPositions, + int nearbyRadius, + int exhaustiveThreshold, + int defaultNodeThreshold, + int defaultMaxPasses + ) { + this.sampleRate = sampleRate; + this.maxPositions = maxPositions; + this.nearbyRadius = nearbyRadius; + this.exhaustiveThreshold = exhaustiveThreshold; + this.defaultNodeThreshold = defaultNodeThreshold; + this.defaultMaxPasses = defaultMaxPasses; + } + } + + private SiftingOptimization(Builder builder) { + this.cfg = SmithyBuilder.requiredState("cfg", builder.cfg); + this.coarseMinNodes = builder.coarseMinNodes; + this.coarseMaxPasses = builder.coarseMaxPasses; + this.mediumMinNodes = builder.mediumMinNodes; + this.mediumMaxPasses = builder.mediumMaxPasses; + this.granularMaxNodes = builder.granularMaxNodes; + this.granularMaxPasses = builder.granularMaxPasses; + this.dependencyGraph = new ConditionDependencyGraph(Arrays.asList(cfg.getConditions())); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public EndpointBddTrait apply(EndpointBddTrait trait) { + try { + return doApply(trait); + } finally { + threadBuilder.remove(); + } + } + + private EndpointBddTrait doApply(EndpointBddTrait trait) { + LOGGER.info("Starting BDD sifting optimization"); + long startTime = System.currentTimeMillis(); + OptimizationState state = initializeOptimization(trait); + LOGGER.info(String.format("Initial size: %d nodes", state.initialSize)); + + state = runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); + state = runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); + if (state.currentSize <= granularMaxNodes) { + state = runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); + } else { + LOGGER.info("Skipping granular stage - too large"); + } + state = runAdjacentSwaps(state); + + double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; + + if (state.bestSize >= state.initialSize) { + LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); + return trait; + } + + LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", + state.initialSize, + state.bestSize, + (1.0 - (double) state.bestSize / state.initialSize) * 100, + totalTimeInSeconds)); + + return trait.toBuilder().conditions(state.orderView).results(state.results).bdd(state.bestBdd).build(); + } + + private OptimizationState initializeOptimization(EndpointBddTrait trait) { + // Use the trait's existing ordering as the starting point + List initialOrder = new ArrayList<>(trait.getConditions()); + Condition[] order = initialOrder.toArray(new Condition[0]); + List orderView = Arrays.asList(order); + Bdd bdd = trait.getBdd(); + int initialSize = bdd.getNodeCount() - 1; + return new OptimizationState(order, orderView, bdd, initialSize, initialSize, trait.getResults()); + } + + private OptimizationState runOptimizationStage( + String stageName, + OptimizationState state, + OptimizationEffort effort, + int targetNodeCount, + int maxPasses, + double minReductionPercent + ) { + if (targetNodeCount > 0 && state.currentSize <= targetNodeCount) { + return state; + } + + LOGGER.info(String.format("Stage: %s optimization (%d nodes%s)", + stageName, + state.currentSize, + targetNodeCount > 0 ? String.format(", target < %d", targetNodeCount) : "")); + + OptimizationState currentState = state; + for (int pass = 1; pass <= maxPasses; pass++) { + if (targetNodeCount > 0 && currentState.currentSize <= targetNodeCount) { + break; + } + + int passStartSize = currentState.currentSize; + OptimizationResult result = runPass(currentState, effort); + if (result.improved) { + currentState = currentState.withResult(result.bdd, result.size, result.results); + double reduction = (1.0 - (double) result.size / passStartSize) * 100; + LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", + stageName, + pass, + passStartSize, + result.size, + reduction)); + if (minReductionPercent > 0 && reduction < minReductionPercent) { + LOGGER.fine(String.format("%s optimization yielding diminishing returns", stageName)); + break; + } + } else { + LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); + break; + } + } + + return currentState; + } + + private OptimizationState runAdjacentSwaps(OptimizationState state) { + if (state.currentSize > granularMaxNodes) { + return state; + } + + LOGGER.info("Running adjacent swaps optimization"); + OptimizationState currentState = state; + + // Run multiple sweeps until no improvement + for (int sweep = 1; sweep <= 3; sweep++) { + OptimizationContext context = new OptimizationContext(currentState, dependencyGraph); + int startSize = currentState.currentSize; + + for (int i = 0; i < currentState.order.length - 1; i++) { + if (context.constraints.canMove(i, i + 1)) { + move(currentState.order, i, i + 1); + BddCompilationResult compilationResult = compileBddWithResults(currentState.orderView); + int swappedSize = compilationResult.bdd.getNodeCount() - 1; + if (swappedSize < context.bestSize) { + context = context.withImprovement( + new PositionResult(i + 1, + swappedSize, + compilationResult.bdd, + compilationResult.results)); + } else { + move(currentState.order, i + 1, i); // Swap back + } + } + } + + if (context.improvements > 0) { + currentState = currentState.withResult(context.bestBdd, context.bestSize, context.bestResults); + LOGGER.fine(String.format("Adjacent swaps sweep %d: %d -> %d nodes", + sweep, + startSize, + context.bestSize)); + } else { + break; + } + } + + return currentState; + } + + private OptimizationResult runPass(OptimizationState state, OptimizationEffort effort) { + OptimizationContext context = new OptimizationContext(state, dependencyGraph); + + List selectedConditions = IntStream.range(0, state.orderView.size()) + .filter(i -> i % effort.sampleRate == 0) + .mapToObj(state.orderView::get) + .collect(Collectors.toList()); + + for (Condition condition : selectedConditions) { + Integer varIdx = context.liveIndex.get(condition); + if (varIdx == null) { + continue; + } + + List positions = getStrategicPositions(varIdx, context.constraints, effort); + if (positions.isEmpty()) { + continue; + } + + context = tryImprovePosition(context, varIdx, positions); + } + + return context.toResult(); + } + + private OptimizationContext tryImprovePosition(OptimizationContext context, int varIdx, List positions) { + PositionResult best = findBestPosition(positions, context, varIdx); + if (best != null && best.count <= context.bestSize) { // Accept ties + move(context.order, varIdx, best.position); + return context.withImprovement(best); + } + + return context; + } + + private PositionResult findBestPosition(List positions, OptimizationContext ctx, int varIdx) { + return (positions.size() > PARALLEL_THRESHOLD ? positions.parallelStream() : positions.stream()) + .map(pos -> { + Condition[] order = ctx.order.clone(); + move(order, varIdx, pos); + BddCompilationResult cr = compileBddWithResults(Arrays.asList(order)); + return new PositionResult(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results); + }) + .filter(pr -> pr.count <= ctx.bestSize) + .min(Comparator.comparingInt((PositionResult pr) -> pr.count).thenComparingInt(pr -> pr.position)) + .orElse(null); + } + + private static List getStrategicPositions( + int varIdx, + ConditionDependencyGraph.OrderConstraints constraints, + OptimizationEffort effort + ) { + int min = constraints.getMinValidPosition(varIdx); + int max = constraints.getMaxValidPosition(varIdx); + int range = max - min; + + if (range <= effort.exhaustiveThreshold) { + List positions = new ArrayList<>(range); + for (int p = min; p < max; p++) { + if (p != varIdx && constraints.canMove(varIdx, p)) { + positions.add(p); + } + } + return positions; + } + + List positions = new ArrayList<>(effort.maxPositions); + + // Test extremes first since they often yield the best improvements + if (min != varIdx && constraints.canMove(varIdx, min)) { + positions.add(min); + } + if (positions.size() >= effort.maxPositions) { + return positions; + } + + if (max - 1 != varIdx && constraints.canMove(varIdx, max - 1)) { + positions.add(max - 1); + } + if (positions.size() >= effort.maxPositions) { + return positions; + } + + // Test local moves that preserve relative ordering with neighbors + for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { + if (offset != 0) { + if (positions.size() >= effort.maxPositions) { + return positions; + } + int p = varIdx + offset; + if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { + positions.add(p); + } + } + } + + // Sample intermediate positions to find global improvements + if (positions.size() >= effort.maxPositions) { + return positions; + } + + int maxSamples = Math.min(15, effort.maxPositions / 2); + int samples = Math.min(maxSamples, Math.max(2, range / 4)); + int step = Math.max(1, range / samples); + + for (int p = min + step; p < max - step && positions.size() < effort.maxPositions; p += step) { + if (p != varIdx && !positions.contains(p) && constraints.canMove(varIdx, p)) { + positions.add(p); + } + } + return positions; + } + + private static void move(Condition[] arr, int from, int to) { + if (from == to) { + return; + } + + Condition moving = arr[from]; + if (from < to) { + System.arraycopy(arr, from + 1, arr, from, to - from); + } else { + System.arraycopy(arr, to, arr, to + 1, from - to); + } + arr[to] = moving; + } + + private static Map rebuildIndex(List orderView) { + Map index = new IdentityHashMap<>(); + for (int i = 0; i < orderView.size(); i++) { + index.put(orderView.get(i), i); + } + return index; + } + + private BddCompilationResult compileBddWithResults(List ordering) { + BddBuilder builder = threadBuilder.get().reset(); + BddCompiler compiler = new BddCompiler(cfg, OrderingStrategy.fixed(ordering), builder); + Bdd bdd = compiler.compile(); + return new BddCompilationResult(bdd, compiler.getIndexedResults()); + } + + // Helper class to track optimization context within a pass + private static final class OptimizationContext { + final Condition[] order; + final List orderView; + final ConditionDependencyGraph dependencyGraph; + final ConditionDependencyGraph.OrderConstraints constraints; + final Map liveIndex; + final Bdd bestBdd; + final int bestSize; + final List bestResults; + final int improvements; + + OptimizationContext(OptimizationState state, ConditionDependencyGraph dependencyGraph) { + this.order = state.order; + this.orderView = state.orderView; + this.dependencyGraph = dependencyGraph; + this.constraints = dependencyGraph.createOrderConstraints(orderView); + this.liveIndex = rebuildIndex(orderView); + this.bestBdd = null; + this.bestSize = state.currentSize; + this.bestResults = null; + this.improvements = 0; + } + + private OptimizationContext( + Condition[] order, + List orderView, + ConditionDependencyGraph dependencyGraph, + ConditionDependencyGraph.OrderConstraints constraints, + Map liveIndex, + Bdd bestBdd, + int bestSize, + List bestResults, + int improvements + ) { + this.order = order; + this.orderView = orderView; + this.dependencyGraph = dependencyGraph; + this.constraints = constraints; + this.liveIndex = liveIndex; + this.bestBdd = bestBdd; + this.bestSize = bestSize; + this.bestResults = bestResults; + this.improvements = improvements; + } + + OptimizationContext withImprovement(PositionResult result) { + ConditionDependencyGraph.OrderConstraints newConstraints = + dependencyGraph.createOrderConstraints(orderView); + Map newIndex = rebuildIndex(orderView); + return new OptimizationContext(order, + orderView, + dependencyGraph, + newConstraints, + newIndex, + result.bdd, + result.count, + result.results, + improvements + 1); + } + + OptimizationResult toResult() { + return new OptimizationResult(bestBdd, bestSize, improvements > 0, bestResults); + } + } + + private static final class BddCompilationResult { + final Bdd bdd; + final List results; + + BddCompilationResult(Bdd bdd, List results) { + this.bdd = bdd; + this.results = results; + } + } + + private static final class PositionResult { + final int position; + final int count; + final Bdd bdd; + final List results; + + PositionResult(int position, int count, Bdd bdd, List results) { + this.position = position; + this.count = count; + this.bdd = bdd; + this.results = results; + } + } + + private static final class OptimizationResult { + final Bdd bdd; + final int size; + final boolean improved; + final List results; + + OptimizationResult(Bdd bdd, int size, boolean improved, List results) { + this.bdd = bdd; + this.size = size; + this.improved = improved; + this.results = results; + } + } + + private static final class OptimizationState { + final Condition[] order; + final List orderView; + final Bdd bestBdd; + final int currentSize; + final int bestSize; + final int initialSize; + final List results; + + OptimizationState( + Condition[] order, + List orderView, + Bdd bestBdd, + int currentSize, + int initialSize, + List results + ) { + this.order = order; + this.orderView = orderView; + this.bestBdd = bestBdd; + this.currentSize = currentSize; + this.bestSize = currentSize; + this.initialSize = initialSize; + this.results = results; + } + + OptimizationState withResult(Bdd newBdd, int newSize, List newResults) { + return new OptimizationState(order, orderView, newBdd, newSize, initialSize, newResults); + } + } + + public static final class Builder implements SmithyBuilder { + private Cfg cfg; + private int coarseMinNodes = OptimizationEffort.COARSE.defaultNodeThreshold; + private int coarseMaxPasses = OptimizationEffort.COARSE.defaultMaxPasses; + private int mediumMinNodes = OptimizationEffort.MEDIUM.defaultNodeThreshold; + private int mediumMaxPasses = OptimizationEffort.MEDIUM.defaultMaxPasses; + private int granularMaxNodes = OptimizationEffort.GRANULAR.defaultNodeThreshold; + private int granularMaxPasses = OptimizationEffort.GRANULAR.defaultMaxPasses; + + private Builder() {} + + /** + * Sets the required control flow graph to optimize. + * + * @param cfg the control flow graph + * @return this builder + */ + public Builder cfg(Cfg cfg) { + this.cfg = cfg; + return this; + } + + /** + * Sets the coarse optimization parameters. + * + *

Coarse optimization runs until the BDD has fewer than minNodeCount nodes + * or maxPasses have been completed. + * + * @param minNodeCount the target size to stop coarse optimization (default: 4,000) + * @param maxPasses the maximum number of coarse passes (default: 6) + * @return this builder + */ + public Builder coarseEffort(int minNodeCount, int maxPasses) { + this.coarseMinNodes = minNodeCount; + this.coarseMaxPasses = maxPasses; + return this; + } + + /** + * Sets the medium optimization parameters. + * + *

Medium optimization runs until the BDD has fewer than minNodeCount nodes + * or maxPasses have been completed. + * + * @param minNodeCount the target size to stop medium optimization (default: 1,000) + * @param maxPasses the maximum number of medium passes (default: 6) + * @return this builder + */ + public Builder mediumEffort(int minNodeCount, int maxPasses) { + this.mediumMinNodes = minNodeCount; + this.mediumMaxPasses = maxPasses; + return this; + } + + /** + * Sets the granular optimization parameters. + * + *

Granular optimization only runs if the BDD has fewer than maxNodeCount nodes, + * and runs for at most maxPasses. + * + * @param maxNodeCount the maximum size to attempt granular optimization (default: 8,000) + * @param maxPasses the maximum number of granular passes (default: 12) + * @return this builder + */ + public Builder granularEffort(int maxNodeCount, int maxPasses) { + this.granularMaxNodes = maxNodeCount; + this.granularMaxPasses = maxPasses; + return this; + } + + @Override + public SiftingOptimization build() { + return new SiftingOptimization(this); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/UniqueTable.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/UniqueTable.java new file mode 100644 index 00000000000..d83e7843509 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/UniqueTable.java @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.HashMap; +import java.util.Map; + +/** + * A specialized hash table for BDD node deduplication using triple (var, high, low) keys. + */ +final class UniqueTable { + private final Map table; + private final TripleKey mutableKey = new TripleKey(0, 0, 0); + + public UniqueTable() { + this.table = new HashMap<>(); + } + + public UniqueTable(int initialCapacity) { + this.table = new HashMap<>(initialCapacity); + } + + public Integer get(int var, int high, int low) { + mutableKey.update(var, high, low); + return table.get(mutableKey); + } + + public void put(int var, int high, int low, int nodeIndex) { + table.put(new TripleKey(var, high, low), nodeIndex); + } + + public void clear() { + table.clear(); + } + + public int size() { + return table.size(); + } + + private static final class TripleKey { + private int a, b, c, hash; + + private TripleKey(int a, int b, int c) { + update(a, b, c); + } + + TripleKey update(int a, int b, int c) { + this.a = a; + this.b = b; + this.c = c; + int i = (a * 31 + b) * 31 + c; + this.hash = (i ^ (i >>> 16)); + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof TripleKey)) { + return false; + } + TripleKey k = (TripleKey) o; + return a == k.a && b == k.b && c == k.c; + } + + @Override + public int hashCode() { + return hash; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java new file mode 100644 index 00000000000..3924242ef47 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java @@ -0,0 +1,318 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.utils.SmithyBuilder; + +/** + * A Control Flow Graph (CFG) representation of endpoint rule decision logic. + * + *

The CFG transforms the hierarchical decision tree structure into an optimized + * representation with node deduplication to prevent exponential growth. + * + *

The CFG consists of: + *

    + *
  • A root node representing the entry point of the decision logic
  • + *
  • A DAG structure where condition nodes are shared when they have identical subtrees
  • + *
+ */ +public final class Cfg implements Iterable { + + private final Parameters parameters; + private final CfgNode root; + + // Lazily computed condition data + private Condition[] conditions; + private Map conditionToIndex; + private final RulesVersion version; + + Cfg(EndpointRuleSet ruleSet, CfgNode root) { + this( + ruleSet == null ? Parameters.builder().build() : ruleSet.getParameters(), + root, + ruleSet == null ? RulesVersion.V1_1 : ruleSet.getRulesVersion()); + } + + Cfg(Parameters parameters, CfgNode root, RulesVersion version) { + this.root = SmithyBuilder.requiredState("root", root); + this.version = version; + this.parameters = parameters; + } + + /** + * Create a CFG from the given ruleset. + * + * @param ruleSet Rules to convert to CFG. + * @return the CFG result. + */ + public static Cfg from(EndpointRuleSet ruleSet) { + CfgBuilder builder = new CfgBuilder(ruleSet); + CfgNode terminal = ResultNode.terminal(); + Map processedRules = new HashMap<>(); + CfgNode root = convertRulesToChain(builder.ruleSet.getRules(), terminal, builder, processedRules); + return builder.build(root); + } + + /** + * Get the endpoint ruleset version of the CFG. + * + * @return endpoint ruleset version. + */ + public RulesVersion getVersion() { + return version; + } + + /** + * Gets all unique conditions in the CFG, in the order they were discovered. + * + * @return array of conditions + */ + public Condition[] getConditions() { + ensureConditionsExtracted(); + return conditions; + } + + /** + * Gets the index of a condition in the conditions array. + * + * @param condition the condition to look up + * @return the index, or null if not found + */ + public Integer getConditionIndex(Condition condition) { + ensureConditionsExtracted(); + return conditionToIndex.get(condition); + } + + /** + * Gets the number of unique conditions in the CFG. + * + * @return the condition count + */ + public int getConditionCount() { + ensureConditionsExtracted(); + return conditions.length; + } + + private void ensureConditionsExtracted() { + if (conditions == null) { + extractConditions(); + } + } + + private synchronized void extractConditions() { + if (conditions != null) { + return; + } + + List conditionList = new ArrayList<>(); + Map indexMap = new LinkedHashMap<>(); + + for (CfgNode node : this) { + if (node instanceof ConditionNode) { + ConditionNode condNode = (ConditionNode) node; + Condition condition = condNode.getCondition().getCondition(); + + if (!indexMap.containsKey(condition)) { + indexMap.put(condition, conditionList.size()); + conditionList.add(condition); + } + } + } + + this.conditions = conditionList.toArray(new Condition[0]); + this.conditionToIndex = indexMap; + } + + public Parameters getParameters() { + return parameters; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } else { + Cfg o = (Cfg) object; + return root.equals(o.root) && version.equals(o.version); + } + } + + @Override + public int hashCode() { + return Objects.hash(root, version); + } + + /** + * Returns the root node of the control flow graph. + * + * @return the root node + */ + public CfgNode getRoot() { + return root; + } + + @Override + public Iterator iterator() { + return new Iterator() { + private final Deque stack = new ArrayDeque<>(); + private final Set visited = new HashSet<>(); + private CfgNode next; + + { + if (root != null) { + stack.push(root); + } + advance(); + } + + private void advance() { + next = null; + while (!stack.isEmpty()) { + CfgNode node = stack.pop(); + if (visited.add(node)) { + // Push children before returning this node + if (node instanceof ConditionNode) { + ConditionNode cond = (ConditionNode) node; + stack.push(cond.getFalseBranch()); + stack.push(cond.getTrueBranch()); + } + next = node; + return; + } + } + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public CfgNode next() { + if (next == null) { + throw new NoSuchElementException(); + } + CfgNode result = next; + advance(); + return result; + } + }; + } + + // Converts a list of rules into a conditional chain. Each rule's false branch goes to the next rule. + private static CfgNode convertRulesToChain( + List rules, + CfgNode fallthrough, + CfgBuilder builder, + Map processedRules + ) { + // Make a reversed view of the rules list + List reversed = new ArrayList<>(rules); + Collections.reverse(reversed); + CfgNode next = fallthrough; + for (Rule rule : reversed) { + next = convertRule(rule, next, builder, processedRules); + } + return next; + } + + /** + * Converts a single rule to CFG nodes. + * + * @param rule the rule to convert + * @param fallthrough what to do if this rule doesn't match + * @param builder the CFG builder + * @param processedRules cache for processed rules + * @return the entry point for this rule + */ + private static CfgNode convertRule( + Rule rule, + CfgNode fallthrough, + CfgBuilder builder, + Map processedRules + ) { + RuleKey key = new RuleKey(rule, fallthrough); + CfgNode existing = processedRules.get(key); + if (existing != null) { + return existing; + } + + CfgNode body; + if (rule instanceof EndpointRule || rule instanceof ErrorRule) { + body = builder.createResult(rule); + } else if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + // Recursively convert nested rules with same fallthrough + body = convertRulesToChain(treeRule.getRules(), fallthrough, builder, processedRules); + } else { + throw new IllegalArgumentException("Unknown rule type: " + rule.getClass()); + } + + // Build conditions from last to first + CfgNode current = body; + for (int i = rule.getConditions().size() - 1; i >= 0; i--) { + Condition cond = rule.getConditions().get(i); + // For chained conditions (AND semantics), if one fails, we go to the fallthrough + current = builder.createCondition(cond, current, fallthrough); + } + + // Cache the result for this (rule, fallthrough) combination + processedRules.put(key, current); + + return current; + } + + private static final class RuleKey { + private final Rule rule; + private final CfgNode fallthrough; + private final int hashCode; + + RuleKey(Rule rule, CfgNode fallthrough) { + this.rule = rule; + this.fallthrough = fallthrough; + // Use identity hash for fallthrough since it's a node reference + this.hashCode = System.identityHashCode(rule) * 31 + System.identityHashCode(fallthrough); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof RuleKey)) { + return false; + } + RuleKey that = (RuleKey) o; + return rule == that.rule && fallthrough == that.fallthrough; + } + + @Override + public int hashCode() { + return hashCode; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java new file mode 100644 index 00000000000..1cadcce2525 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -0,0 +1,210 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionReference; + +/** + * Builder for constructing Control Flow Graphs with node deduplication. + * + *

This builder performs hash-consing during construction to share identical + * subtrees and prevent exponential growth. + */ +public final class CfgBuilder { + + final EndpointRuleSet ruleSet; + + // Node deduplication + private final Map nodeCache = new HashMap<>(); + + // Condition and result canonicalization + private final Map conditionToReference = new HashMap<>(); + private final Map resultCache = new HashMap<>(); + private final Map resultNodeCache = new HashMap<>(); + + public CfgBuilder(EndpointRuleSet ruleSet) { + // Apply SSA transform to ensure globally unique variable names + this.ruleSet = SsaTransform.transform(ruleSet); + } + + /** + * Build the CFG with the given root node. + * + * @param root Root node to use for the built CFG. + * @return the built CFG. + */ + public Cfg build(CfgNode root) { + return new Cfg(ruleSet, Objects.requireNonNull(root)); + } + + /** + * Creates a condition node, reusing existing nodes when possible. + */ + public CfgNode createCondition(Condition condition, CfgNode trueBranch, CfgNode falseBranch) { + return createCondition(createConditionReference(condition), trueBranch, falseBranch); + } + + /** + * Creates a condition node, reusing existing nodes when possible. + */ + public CfgNode createCondition(ConditionReference condRef, CfgNode trueBranch, CfgNode falseBranch) { + NodeSignature signature = new NodeSignature(condRef, trueBranch, falseBranch); + return nodeCache.computeIfAbsent(signature, key -> new ConditionNode(condRef, trueBranch, falseBranch)); + } + + /** + * Creates a result node representing a terminal rule evaluation. + */ + public CfgNode createResult(Rule rule) { + // Intern the result + Rule interned = intern(rule); + + // Regular result node + return resultNodeCache.computeIfAbsent(interned, ResultNode::new); + } + + /** + * Creates a canonical condition reference, handling negation and deduplication. + */ + public ConditionReference createConditionReference(Condition condition) { + ConditionReference cached = conditionToReference.get(condition); + if (cached != null) { + return cached; + } + + boolean negated = false; + Condition canonical = condition; + + if (isNegationWrapper(condition)) { + negated = true; + canonical = unwrapNegation(condition); + + ConditionReference existing = conditionToReference.get(canonical); + if (existing != null) { + ConditionReference negatedReference = existing.negate(); + conditionToReference.put(condition, negatedReference); + return negatedReference; + } + } + + canonical = canonical.canonicalize(); + + Condition beforeBooleanCanon = canonical; + canonical = canonicalizeBooleanEquals(canonical); + + if (!canonical.equals(beforeBooleanCanon)) { + negated = !negated; + } + + ConditionReference reference = new ConditionReference(canonical, negated); + conditionToReference.put(condition, reference); + + if (!negated && !condition.equals(canonical)) { + conditionToReference.put(canonical, reference); + } + + return reference; + } + + private Rule intern(Rule rule) { + return resultCache.computeIfAbsent(canonicalizeResult(rule), k -> k); + } + + private Rule canonicalizeResult(Rule rule) { + return rule == null ? null : rule.withConditions(Collections.emptyList()); + } + + private Condition canonicalizeBooleanEquals(Condition condition) { + if (!(condition.getFunction() instanceof BooleanEquals)) { + return condition; + } + + List args = condition.getFunction().getArguments(); + if (args.size() != 2 || !(args.get(0) instanceof Reference) || !(args.get(1) instanceof Literal)) { + return condition; + } + + Reference ref = (Reference) args.get(0); + Boolean literalValue = ((Literal) args.get(1)).asBooleanLiteral().orElse(null); + + if (literalValue != null && !literalValue && ruleSet != null) { + String varName = ref.getName().toString(); + Optional param = ruleSet.getParameters().get(Identifier.of(varName)); + if (param.isPresent() && param.get().getDefault().isPresent()) { + return condition.toBuilder().fn(BooleanEquals.ofExpressions(ref, true)).build(); + } + } + + return condition; + } + + private static boolean isNegationWrapper(Condition condition) { + return condition.getFunction() instanceof Not + && !condition.getResult().isPresent() + && condition.getFunction().getArguments().get(0) instanceof LibraryFunction; + } + + private static Condition unwrapNegation(Condition negatedCondition) { + return negatedCondition.toBuilder() + .fn((LibraryFunction) negatedCondition.getFunction().getArguments().get(0)) + .build(); + } + + /** + * Signature for node deduplication during construction. + */ + private static final class NodeSignature { + private final ConditionReference condition; + private final CfgNode trueBranch; + private final CfgNode falseBranch; + private final int hashCode; + + NodeSignature(ConditionReference condition, CfgNode trueBranch, CfgNode falseBranch) { + this.condition = condition; + this.trueBranch = trueBranch; + this.falseBranch = falseBranch; + this.hashCode = Objects.hash( + condition, + System.identityHashCode(trueBranch), + System.identityHashCode(falseBranch)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NodeSignature)) { + return false; + } + NodeSignature that = (NodeSignature) o; + return Objects.equals(condition, that.condition) + && trueBranch == that.trueBranch + && falseBranch == that.falseBranch; + } + + @Override + public int hashCode() { + return hashCode; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgNode.java new file mode 100644 index 00000000000..230a7d4483b --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgNode.java @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +/** + * Abstract base class for all nodes in a Control Flow Graph (CFG). + */ +public abstract class CfgNode { + // Package-private "sealed" class. + CfgNode() {} +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java new file mode 100644 index 00000000000..365ebe3e77c --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java @@ -0,0 +1,241 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.RecordType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Coalesces bind-then-use patterns in conditions, identifying conditions that bind a variable followed immediately by + * a condition that uses that variable, and merges them using coalesce. + */ +final class CoalesceTransform { + private static final Logger LOGGER = Logger.getLogger(CoalesceTransform.class.getName()); + + private final Map coalesceCache = new HashMap<>(); + private int coalesceCount = 0; + private int cacheHits = 0; + private int skippedNoZeroValue = 0; + private int skippedMultipleUses = 0; + private final Set skippedRecordTypes = new HashSet<>(); + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + CoalesceTransform transform = new CoalesceTransform(); + + List transformedRules = new ArrayList<>(); + for (int i = 0; i < ruleSet.getRules().size(); i++) { + transformedRules.add(transform.transformRule(ruleSet.getRules().get(i), "root/rule[" + i + "]")); + } + + if (LOGGER.isLoggable(Level.INFO)) { + LOGGER.info(String.format( + "Coalescing: %d coalesced, %d cache hits, %d skipped (no zero), %d skipped (multiple uses)", + transform.coalesceCount, + transform.cacheHits, + transform.skippedNoZeroValue, + transform.skippedMultipleUses)); + } + + return EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(transformedRules) + .version(ruleSet.getVersion()) + .build(); + } + + private Rule transformRule(Rule rule, String rulePath) { + // Count local usage for THIS rule's conditions + Map localVarUsage = new HashMap<>(); + for (Condition condition : rule.getConditions()) { + for (String ref : condition.getFunction().getReferences()) { + localVarUsage.merge(ref, 1, Integer::sum); + } + } + + Set eliminatedConditions = new HashSet<>(); + List transformedConditions = transformConditions( + rule.getConditions(), + eliminatedConditions, + localVarUsage); + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + return TreeRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .treeRule(TreeRewriter.transformNestedRules(treeRule, rulePath, this::transformRule)); + } + + // CoalesceTransform only modifies conditions, not endpoints/errors + return rule.withConditions(transformedConditions); + } + + private List transformConditions( + List conditions, + Set eliminatedConditions, + Map localVarUsage + ) { + List result = new ArrayList<>(); + + for (int i = 0; i < conditions.size(); i++) { + Condition current = conditions.get(i); + if (eliminatedConditions.contains(current)) { + continue; + } + + if (i + 1 < conditions.size() && current.getResult().isPresent()) { + String var = current.getResult().get().toString(); + Condition next = conditions.get(i + 1); + + if (canCoalesce(var, current, next, localVarUsage)) { + result.add(createCoalescedCondition(current, next, var)); + eliminatedConditions.add(current); + eliminatedConditions.add(next); + i++; // Skip next + continue; + } + } + + result.add(current); + } + + return result; + } + + private boolean canCoalesce(String var, Condition bind, Condition use, Map localVarUsage) { + if (!use.getFunction().getReferences().contains(var)) { + return false; + } + + if (use.getFunction().getFunctionDefinition() == IsSet.getDefinition()) { + return false; + } + + Integer localUses = localVarUsage.get(var); + if (localUses == null || localUses > 1) { + skippedMultipleUses++; + return false; + } + + Type type = bind.getFunction().getFunctionDefinition().getReturnType(); + Type innerType = type instanceof OptionalType ? ((OptionalType) type).inner() : type; + + if (innerType instanceof RecordType) { + skippedNoZeroValue++; + skippedRecordTypes.add(bind.getFunction().getName()); + return false; + } + + if (!innerType.getZeroValue().isPresent()) { + skippedNoZeroValue++; + return false; + } + + return true; + } + + private Condition createCoalescedCondition(Condition bind, Condition use, String var) { + LibraryFunction bindExpr = bind.getFunction(); + LibraryFunction useExpr = use.getFunction(); + + Type type = bindExpr.getFunctionDefinition().getReturnType(); + Type innerType = type instanceof OptionalType ? ((OptionalType) type).inner() : type; + Literal zero = innerType.getZeroValue().get(); + + String bindCanonical = bindExpr.canonicalize().toString(); + String zeroCanonical = zero.toString(); + String useCanonical = useExpr.canonicalize().toString(); + String resultVar = use.getResult().map(Identifier::toString).orElse(""); + + CoalesceKey key = new CoalesceKey(bindCanonical, zeroCanonical, useCanonical, var, resultVar); + Condition cached = coalesceCache.get(key); + if (cached != null) { + cacheHits++; + return cached; + } + + Expression coalesced = Coalesce.ofExpressions(bindExpr, zero); + Map replacements = new HashMap<>(); + replacements.put(var, coalesced); + + Expression replaced = TreeRewriter.forReplacements(replacements).rewrite(useExpr); + LibraryFunction canonicalized = ((LibraryFunction) replaced).canonicalize(); + + Condition.Builder builder = Condition.builder().fn(canonicalized); + if (use.getResult().isPresent()) { + builder.result(use.getResult().get()); + } + + Condition result = builder.build(); + coalesceCache.put(key, result); + coalesceCount++; + + if (LOGGER.isLoggable(Level.FINE)) { + LOGGER.fine("Coalesced #" + coalesceCount + ":\n" + + " " + var + " = " + bind.getFunction() + "\n" + + " " + use.getFunction() + "\n" + + " => " + canonicalized); + } + + return result; + } + + private static final class CoalesceKey { + final String bindFunction; + final String zeroValue; + final String useFunction; + final String replacedVar; + final String resultVar; + final int hashCode; + + CoalesceKey(String bindFunction, String zeroValue, String useFunction, String replacedVar, String resultVar) { + this.bindFunction = bindFunction; + this.zeroValue = zeroValue; + this.useFunction = useFunction; + this.replacedVar = replacedVar; + this.resultVar = resultVar; + this.hashCode = Objects.hash(bindFunction, zeroValue, useFunction, replacedVar, resultVar); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof CoalesceKey)) { + return false; + } + CoalesceKey that = (CoalesceKey) o; + return bindFunction.equals(that.bindFunction) && zeroValue.equals(that.zeroValue) + && useFunction.equals(that.useFunction) + && replacedVar.equals(that.replacedVar) + && resultVar.equals(that.resultVar); + } + + @Override + public int hashCode() { + return hashCode; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraph.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraph.java new file mode 100644 index 00000000000..5b4730f3633 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraph.java @@ -0,0 +1,315 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * Graph of dependencies between conditions based on variable definitions and usage. + * + *

This class performs AST analysis once to extract: + *

    + *
  • Variable definitions - which conditions define which variables
  • + *
  • Variable usage - which conditions use which variables
  • + *
  • Dependencies - which conditions must come before others
  • + *
+ */ +public final class ConditionDependencyGraph { + private final List conditions; + private final Map conditionToIndex; + private final Map> dependencies; + private final Map> variableDefiners; + private final Map> isSetConditions; + + // Indexed dependency information for fast access + private final List> predecessors; + private final List> successors; + + /** + * Creates a dependency graph by analyzing the given conditions. + * + * @param conditions the conditions to analyze + */ + public ConditionDependencyGraph(List conditions) { + this.conditions = Collections.unmodifiableList(new ArrayList<>(conditions)); + this.conditionToIndex = new HashMap<>(); + this.variableDefiners = new HashMap<>(); + this.isSetConditions = new HashMap<>(); + + int n = conditions.size(); + for (int i = 0; i < n; i++) { + conditionToIndex.put(conditions.get(i), i); + } + + // Initialize indexed structures + this.predecessors = new ArrayList<>(n); + this.successors = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + predecessors.add(new HashSet<>()); + successors.add(new HashSet<>()); + } + + // Categorize all conditions + for (Condition cond : conditions) { + // Track variable definition + if (cond.getResult().isPresent()) { + String definedVar = cond.getResult().get().toString(); + variableDefiners.computeIfAbsent(definedVar, k -> new HashSet<>()).add(cond); + } + + // Track isSet conditions + if (isIsSet(cond)) { + for (String var : cond.getFunction().getReferences()) { + isSetConditions.computeIfAbsent(var, k -> new HashSet<>()).add(cond); + } + } + } + + // Compute dependencies + Map> deps = new HashMap<>(); + Map> producers = new HashMap<>(); + Map> isSetters = new HashMap<>(); + + // Build producer and isSet indices using Identifier + for (int i = 0; i < n; i++) { + Condition c = conditions.get(i); + + if (c.getResult().isPresent()) { + Identifier var = c.getResult().get(); + producers.computeIfAbsent(var, k -> new HashSet<>()).add(i); + } + + if (isIsSet(c)) { + for (String ref : c.getFunction().getReferences()) { + Identifier var = Identifier.of(ref); + isSetters.computeIfAbsent(var, k -> new HashSet<>()).add(i); + } + } + } + + // Build both object-based and index-based dependencies + for (int i = 0; i < n; i++) { + Condition cond = conditions.get(i); + Set condDeps = new HashSet<>(); + + for (String usedVar : cond.getFunction().getReferences()) { + // Object-based dependencies + condDeps.addAll(variableDefiners.getOrDefault(usedVar, Collections.emptySet())); + if (!isIsSet(cond)) { + condDeps.addAll(isSetConditions.getOrDefault(usedVar, Collections.emptySet())); + } + + // Index-based dependencies + Identifier var = Identifier.of(usedVar); + for (int prod : producers.getOrDefault(var, Collections.emptySet())) { + if (prod != i) { + predecessors.get(i).add(prod); + successors.get(prod).add(i); + } + } + + if (!isIsSet(cond)) { + for (int setter : isSetters.getOrDefault(var, Collections.emptySet())) { + if (setter != i) { + predecessors.get(i).add(setter); + successors.get(setter).add(i); + } + } + } + } + + condDeps.remove(cond); // Remove self-dependencies + if (!condDeps.isEmpty()) { + deps.put(cond, Collections.unmodifiableSet(condDeps)); + } + } + + this.dependencies = Collections.unmodifiableMap(deps); + } + + /** + * Gets the dependencies for a condition. + * + * @param condition the condition to query + * @return set of conditions that must come before it (never null) + */ + public Set getDependencies(Condition condition) { + return dependencies.getOrDefault(condition, Collections.emptySet()); + } + + /** + * Gets the predecessors (dependencies) for a condition by index. + * + * @param index the condition index + * @return set of predecessor indices + */ + public Set getPredecessors(int index) { + return predecessors.get(index); + } + + /** + * Gets the successors (dependents) for a condition by index. + * + * @param index the condition index + * @return set of successor indices + */ + public Set getSuccessors(int index) { + return successors.get(index); + } + + /** + * Gets the number of predecessors for a condition. + * + * @param index the condition index + * @return the predecessor count + */ + public int getPredecessorCount(int index) { + return predecessors.get(index).size(); + } + + /** + * Gets the number of successors for a condition. + * + * @param index the condition index + * @return the successor count + */ + public int getSuccessorCount(int index) { + return successors.get(index).size(); + } + + /** + * Checks if there's a dependency from one condition to another. + * + * @param from the dependent condition index + * @param to the dependency condition index + * @return true if 'from' depends on 'to' + */ + public boolean hasDependency(int from, int to) { + return predecessors.get(from).contains(to); + } + + /** + * Creates order constraints for a specific ordering of conditions. + * + * @param ordering the ordering to compute constraints for + * @return the order constraints + */ + public OrderConstraints createOrderConstraints(List ordering) { + return new OrderConstraints(ordering); + } + + /** + * Gets the index mapping for conditions. + * + * @return map from condition to index + */ + public Map getConditionToIndex() { + return Collections.unmodifiableMap(conditionToIndex); + } + + /** + * Gets the number of conditions in this dependency graph. + * + * @return the number of conditions + */ + public int size() { + return conditions.size(); + } + + private static boolean isIsSet(Condition cond) { + return cond.getFunction().getFunctionDefinition() == IsSet.getDefinition(); + } + + /** + * Order-specific constraints for a particular condition ordering. + */ + public final class OrderConstraints { + private final Condition[] orderedConditions; + private final Map orderIndex; + private final int[] minValidPosition; + private final int[] maxValidPosition; + + private OrderConstraints(List ordering) { + int n = ordering.size(); + if (n != conditions.size()) { + throw new IllegalArgumentException( + "Ordering size (" + n + ") doesn't match dependency graph size (" + conditions.size() + ")"); + } + + this.orderedConditions = ordering.toArray(new Condition[0]); + this.orderIndex = new HashMap<>(n * 2); + this.minValidPosition = new int[n]; + this.maxValidPosition = new int[n]; + + // Build index mapping for this ordering + for (int i = 0; i < n; i++) { + orderIndex.put(orderedConditions[i], i); + } + + // Compute valid positions based on dependencies + for (int i = 0; i < n; i++) { + maxValidPosition[i] = n - 1; // Initialize max position + + Condition cond = orderedConditions[i]; + Integer originalIdx = conditionToIndex.get(cond); + if (originalIdx == null) { + throw new IllegalArgumentException("Condition not in dependency graph: " + cond); + } + + // Check all dependencies + for (int depIdx : predecessors.get(originalIdx)) { + Condition depCond = conditions.get(depIdx); + Integer depOrderIdx = orderIndex.get(depCond); + if (depOrderIdx != null) { + // This condition must come after its dependency + minValidPosition[i] = Math.max(minValidPosition[i], depOrderIdx + 1); + // The dependency must come before this condition + maxValidPosition[depOrderIdx] = Math.min(maxValidPosition[depOrderIdx], i - 1); + } + } + } + } + + /** + * Checks if moving a condition from one position to another would violate dependencies. + * + * @param from current position + * @param to target position + * @return true if the move is valid + */ + public boolean canMove(int from, int to) { + return from == to || (to >= minValidPosition[from] && to <= maxValidPosition[from]); + } + + /** + * Gets the minimum valid position for a condition. + * + * @param positionIndex the position index in the ordering + * @return the minimum position where this condition can be placed + */ + public int getMinValidPosition(int positionIndex) { + return minValidPosition[positionIndex]; + } + + /** + * Gets the maximum valid position for a condition. + * + * @param positionIndex the position index in the ordering + * @return the maximum position where this condition can be placed + */ + public int getMaxValidPosition(int positionIndex) { + return maxValidPosition[positionIndex]; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionNode.java new file mode 100644 index 00000000000..c6eb29b5fe0 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionNode.java @@ -0,0 +1,83 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.Objects; +import software.amazon.smithy.rulesengine.logic.ConditionReference; + +/** + * A CFG node that evaluates a condition and branches based on the result. + */ +public final class ConditionNode extends CfgNode { + + private final ConditionReference condition; + private final CfgNode trueBranch; + private final CfgNode falseBranch; + private final int hash; + + /** + * Creates a new condition node. + * + * @param condition condition reference (can be negated) + * @param trueBranch node to evaluate if the condition is true + * @param falseBranch node to evaluate if the condition is false + */ + public ConditionNode(ConditionReference condition, CfgNode trueBranch, CfgNode falseBranch) { + this.condition = Objects.requireNonNull(condition); + this.trueBranch = Objects.requireNonNull(trueBranch, "trueBranch must not be null"); + this.falseBranch = Objects.requireNonNull(falseBranch, "falseBranch must not be null"); + this.hash = Objects.hash(condition, trueBranch, falseBranch); + } + + /** + * Returns the condition reference for this node. + * + * @return the condition reference + */ + public ConditionReference getCondition() { + return condition; + } + + /** + * Returns the node to evaluate if the condition is true. + * + * @return the true branch node + */ + public CfgNode getTrueBranch() { + return trueBranch; + } + + /** + * Returns the node to evaluate if the condition is false. + * + * @return the false branch node + */ + public CfgNode getFalseBranch() { + return falseBranch; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } + ConditionNode o = (ConditionNode) object; + return condition.equals(o.condition) && trueBranch.equals(o.trueBranch) && falseBranch.equals(o.falseBranch); + } + + @Override + public int hashCode() { + return hash; + } + + @Override + public String toString() { + return "ConditionNode{condition=" + condition + + ", trueBranch=" + System.identityHashCode(trueBranch) + + ", falseBranch=" + System.identityHashCode(falseBranch) + '}'; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ResultNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ResultNode.java new file mode 100644 index 00000000000..509e9ee6873 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ResultNode.java @@ -0,0 +1,65 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.Objects; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +/** + * A terminal CFG node that represents a final result, either an endpoint or error. + */ +public final class ResultNode extends CfgNode { + private final Rule result; + private final int hash; + private static final ResultNode TERMINAL = new ResultNode(); + + public ResultNode(Rule result) { + this.result = result; + this.hash = result == null ? 11 : result.hashCode(); + } + + private ResultNode() { + this(null); + } + + /** + * Returns a terminal node representing no match. + * + * @return the terminal result node + */ + public static ResultNode terminal() { + return TERMINAL; + } + + /** + * Get the underlying result. + * + * @return the result value. + */ + public Rule getResult() { + return result; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } else { + return Objects.equals(result, (((ResultNode) object)).result); + } + } + + @Override + public int hashCode() { + return hash; + } + + @Override + public String toString() { + return "ResultNode{hash=" + hash + ", result=" + result + '}'; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java new file mode 100644 index 00000000000..f1da1f16520 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java @@ -0,0 +1,277 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Transforms a decision tree into Static Single Assignment (SSA) form. + * + *

This transformation ensures that each variable is assigned exactly once by renaming variables when they are + * reassigned in different parts of the tree. For example, if variable "x" is assigned in multiple branches, they + * become "x_ssa_1", "x_ssa_2", "x_ssa_3", etc. Without this transform, the BDD compilation would confuse divergent + * paths that have the same variable name. + * + *

Note that this transform is only applied when the reassignment is done using different + * arguments than previously seen assignments of the same variable name. + */ +final class SsaTransform { + + private final Deque> scopeStack = new ArrayDeque<>(); + private final Map rewrittenConditions = new IdentityHashMap<>(); + private final Map rewrittenRules = new IdentityHashMap<>(); + private final VariableAnalysis variableAnalysis; + private final TreeRewriter referenceRewriter; + + private SsaTransform(VariableAnalysis variableAnalysis) { + scopeStack.push(new HashMap<>()); + this.variableAnalysis = variableAnalysis; + this.referenceRewriter = new TreeRewriter(this::referenceRewriter, this::needsRewriting); + } + + private Expression referenceRewriter(Reference ref) { + String originalName = ref.getName().toString(); + String uniqueName = resolveReference(originalName); + return Expression.getReference(Identifier.of(uniqueName)); + } + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + ruleSet = VariableConsolidationTransform.transform(ruleSet); + ruleSet = CoalesceTransform.transform(ruleSet); + VariableAnalysis variableAnalysis = VariableAnalysis.analyze(ruleSet); + SsaTransform ssaTransform = new SsaTransform(variableAnalysis); + + List rewrittenRules = new ArrayList<>(ruleSet.getRules().size()); + for (Rule original : ruleSet.getRules()) { + rewrittenRules.add(ssaTransform.processRule(original)); + } + + return EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(rewrittenRules) + .version(ruleSet.getVersion()) + .build(); + } + + private Rule processRule(Rule rule) { + enterScope(); + Rule rewrittenRule = rewriteRule(rule); + exitScope(); + return rewrittenRule; + } + + private void enterScope() { + scopeStack.push(new HashMap<>(scopeStack.peek())); + } + + private void exitScope() { + if (scopeStack.size() <= 1) { + throw new IllegalStateException("Cannot exit global scope"); + } + scopeStack.pop(); + } + + private Condition rewriteCondition(Condition condition) { + boolean hasBinding = condition.getResult().isPresent(); + + if (!hasBinding) { + Condition cached = rewrittenConditions.get(condition); + if (cached != null) { + return cached; + } + } + + LibraryFunction fn = condition.getFunction(); + Set rewritableRefs = filterOutInputParameters(fn.getReferences()); + + String uniqueBindingName = null; + boolean needsUniqueBinding = false; + if (hasBinding) { + String varName = condition.getResult().get().toString(); + + // Only need SSA rename if variable has multiple bindings + if (variableAnalysis.hasMultipleBindings(varName)) { + Map expressionMap = variableAnalysis.getExpressionMappings().get(varName); + if (expressionMap != null) { + uniqueBindingName = expressionMap.get(fn.toString()); + needsUniqueBinding = uniqueBindingName != null && !uniqueBindingName.equals(varName); + } + } + } + + if (!needsRewriting(rewritableRefs) && !needsUniqueBinding) { + if (!hasBinding) { + rewrittenConditions.put(condition, condition); + } + return condition; + } + + LibraryFunction rewrittenExpr = (LibraryFunction) referenceRewriter.rewrite(fn); + boolean exprChanged = rewrittenExpr != fn; + + Condition rewritten; + if (hasBinding && uniqueBindingName != null) { + scopeStack.peek().put(condition.getResult().get().toString(), uniqueBindingName); + if (needsUniqueBinding || exprChanged) { + rewritten = condition.toBuilder().fn(rewrittenExpr).result(Identifier.of(uniqueBindingName)).build(); + } else { + rewritten = condition; + } + } else if (exprChanged) { + rewritten = condition.toBuilder().fn(rewrittenExpr).build(); + } else { + rewritten = condition; + } + + if (!hasBinding) { + rewrittenConditions.put(condition, rewritten); + } + + return rewritten; + } + + private Set filterOutInputParameters(Set references) { + if (references.isEmpty() || variableAnalysis.getInputParams().isEmpty()) { + return references; + } + + Set filtered = new HashSet<>(references); + filtered.removeAll(variableAnalysis.getInputParams()); + return filtered; + } + + private boolean needsRewriting(Set references) { + if (references.isEmpty()) { + return false; + } + + Map currentScope = scopeStack.peek(); + for (String ref : references) { + String mapped = currentScope.get(ref); + if (mapped != null && !mapped.equals(ref)) { + return true; + } + } + return false; + } + + private boolean needsRewriting(Expression expression) { + return needsRewriting(filterOutInputParameters(expression.getReferences())); + } + + private Rule rewriteRule(Rule rule) { + Rule cached = rewrittenRules.get(rule); + if (cached != null) { + return cached; + } + + List rewrittenConditions = rewriteConditions(rule.getConditions()); + boolean conditionsChanged = !rewrittenConditions.equals(rule.getConditions()); + + Rule result; + if (rule instanceof EndpointRule) { + result = rewriteEndpointRule((EndpointRule) rule, rewrittenConditions, conditionsChanged); + } else if (rule instanceof ErrorRule) { + result = rewriteErrorRule((ErrorRule) rule, rewrittenConditions, conditionsChanged); + } else if (rule instanceof TreeRule) { + result = rewriteTreeRule((TreeRule) rule, rewrittenConditions, conditionsChanged); + } else if (conditionsChanged) { + throw new UnsupportedOperationException("Cannot change rule: " + rule); + } else { + result = rule; + } + + rewrittenRules.put(rule, result); + return result; + } + + private List rewriteConditions(List conditions) { + List rewritten = new ArrayList<>(conditions.size()); + for (Condition condition : conditions) { + rewritten.add(rewriteCondition(condition)); + } + return rewritten; + } + + private Rule rewriteEndpointRule( + EndpointRule rule, + List rewrittenConditions, + boolean conditionsChanged + ) { + Endpoint rewrittenEndpoint = referenceRewriter.rewriteEndpoint(rule.getEndpoint()); + + if (conditionsChanged || rewrittenEndpoint != rule.getEndpoint()) { + return EndpointRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(rewrittenConditions) + .endpoint(rewrittenEndpoint); + } + + return rule; + } + + private Rule rewriteErrorRule(ErrorRule rule, List rewrittenConditions, boolean conditionsChanged) { + Expression rewrittenError = referenceRewriter.rewrite(rule.getError()); + + if (conditionsChanged || rewrittenError != rule.getError()) { + return ErrorRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(rewrittenConditions) + .error(rewrittenError); + } + + return rule; + } + + private Rule rewriteTreeRule(TreeRule rule, List rewrittenConditions, boolean conditionsChanged) { + List rewrittenNestedRules = new ArrayList<>(); + boolean nestedChanged = false; + + for (Rule nestedRule : rule.getRules()) { + enterScope(); + Rule rewritten = rewriteRule(nestedRule); + rewrittenNestedRules.add(rewritten); + if (rewritten != nestedRule) { + nestedChanged = true; + } + exitScope(); + } + + if (conditionsChanged || nestedChanged) { + return TreeRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(rewrittenConditions) + .treeRule(rewrittenNestedRules); + } + + return rule; + } + + private String resolveReference(String originalName) { + // Input parameters are never rewritten + return variableAnalysis.getInputParams().contains(originalName) + ? originalName + : scopeStack.peek().getOrDefault(originalName, originalName); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java new file mode 100644 index 00000000000..7690a1cffd9 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java @@ -0,0 +1,292 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Predicate; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Utility for rewriting references within expression trees. + */ +final class TreeRewriter { + // A no-op rewriter that returns expressions unchanged. + static final TreeRewriter IDENTITY = new TreeRewriter(ref -> ref, expr -> false); + + private final Function referenceTransformer; + private final Predicate shouldRewrite; + + /** + * Creates a new reference rewriter. + * + * @param referenceTransformer function to transform references + * @param shouldRewrite predicate to determine if an expression needs rewriting + */ + TreeRewriter( + Function referenceTransformer, + Predicate shouldRewrite + ) { + this.referenceTransformer = referenceTransformer; + this.shouldRewrite = shouldRewrite; + } + + /** + * Creates a simple rewriter that replaces specific references. + * + * @param replacements map of variable names to replacement expressions + * @return a reference rewriter that performs the replacements + */ + static TreeRewriter forReplacements(Map replacements) { + if (replacements.isEmpty()) { + return IDENTITY; + } + return new TreeRewriter( + ref -> replacements.getOrDefault(ref.getName().toString(), ref), + expr -> expr.getReferences().stream().anyMatch(replacements::containsKey)); + } + + static List transformNestedRules( + TreeRule tree, + String parentPath, + BiFunction transformer + ) { + List result = new ArrayList<>(); + for (int i = 0; i < tree.getRules().size(); i++) { + Rule transformed = transformer.apply( + tree.getRules().get(i), + parentPath + "/tree/rule[" + i + "]"); + if (transformed != null) { + result.add(transformed); + } + } + return result; + } + + /** + * Rewrites references within an expression tree. + * + * @param expression the expression to rewrite + * @return the rewritten expression, or the original if no changes needed + */ + Expression rewrite(Expression expression) { + if (!shouldRewrite.test(expression)) { + return expression; + } + + if (expression instanceof StringLiteral) { + return rewriteStringLiteral((StringLiteral) expression); + } else if (expression instanceof TupleLiteral) { + return rewriteTupleLiteral((TupleLiteral) expression); + } else if (expression instanceof RecordLiteral) { + return rewriteRecordLiteral((RecordLiteral) expression); + } else if (expression instanceof Reference) { + return referenceTransformer.apply((Reference) expression); + } else if (expression instanceof LibraryFunction) { + return rewriteLibraryFunction((LibraryFunction) expression); + } + + return expression; + } + + Map> rewriteHeaders(Map> headers) { + if (headers.isEmpty()) { + return headers; + } + + Map> rewritten = null; + boolean changed = false; + + for (Map.Entry> entry : headers.entrySet()) { + List originalValues = entry.getValue(); + List rewrittenValues = null; + + for (int i = 0; i < originalValues.size(); i++) { + Expression original = originalValues.get(i); + Expression rewrittenExpr = rewrite(original); + + if (rewrittenExpr != original) { + if (rewrittenValues == null) { + rewrittenValues = new ArrayList<>(originalValues.subList(0, i)); + } + rewrittenValues.add(rewrittenExpr); + changed = true; + } else if (rewrittenValues != null) { + rewrittenValues.add(original); + } + } + + if (changed && rewritten == null) { + rewritten = new LinkedHashMap<>(); + // Copy all previous entries + for (Map.Entry> prev : headers.entrySet()) { + if (prev.getKey().equals(entry.getKey())) { + break; + } + rewritten.put(prev.getKey(), prev.getValue()); + } + } + + if (rewritten != null) { + rewritten.put(entry.getKey(), + rewrittenValues != null ? rewrittenValues : originalValues); + } + } + + return changed ? rewritten : headers; + } + + Map rewriteProperties(Map properties) { + if (properties.isEmpty()) { + return properties; + } + + Map rewritten = null; + boolean changed = false; + + for (Map.Entry entry : properties.entrySet()) { + Expression rewrittenExpr = rewrite(entry.getValue()); + + if (rewrittenExpr != entry.getValue()) { + if (!(rewrittenExpr instanceof Literal)) { + throw new IllegalStateException("Property value must be a literal"); + } + + if (rewritten == null) { + rewritten = new LinkedHashMap<>(); + // Copy all previous entries + for (Map.Entry prev : properties.entrySet()) { + if (prev.getKey().equals(entry.getKey())) { + break; + } + rewritten.put(prev.getKey(), prev.getValue()); + } + } + + rewritten.put(entry.getKey(), (Literal) rewrittenExpr); + changed = true; + } else if (rewritten != null) { + rewritten.put(entry.getKey(), entry.getValue()); + } + } + + return changed ? rewritten : properties; + } + + Endpoint rewriteEndpoint(Endpoint endpoint) { + Expression rewrittenUrl = rewrite(endpoint.getUrl()); + Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); + Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); + + // Only create new endpoint if something changed + if (rewrittenUrl != endpoint.getUrl() + || rewrittenHeaders != endpoint.getHeaders() + || rewrittenProperties != endpoint.getProperties()) { + return Endpoint.builder() + .url(rewrittenUrl) + .headers(rewrittenHeaders) + .properties(rewrittenProperties) + .build(); + } + return endpoint; + } + + private Expression rewriteStringLiteral(StringLiteral str) { + Template template = str.value(); + if (template.isStatic()) { + return str; + } + + StringBuilder templateBuilder = new StringBuilder(); + boolean changed = false; + + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + Expression original = dynamic.toExpression(); + Expression rewritten = rewrite(original); + if (rewritten != original) { + changed = true; + } + templateBuilder.append('{').append(rewritten).append('}'); + } else { + templateBuilder.append(((Template.Literal) part).getValue()); + } + } + + return changed ? Literal.stringLiteral(Template.fromString(templateBuilder.toString())) : str; + } + + private Expression rewriteTupleLiteral(TupleLiteral tuple) { + List rewrittenMembers = new ArrayList<>(); + boolean changed = false; + + for (Literal member : tuple.members()) { + Literal rewritten = (Literal) rewrite(member); + rewrittenMembers.add(rewritten); + if (rewritten != member) { + changed = true; + } + } + + return changed ? Literal.tupleLiteral(rewrittenMembers) : tuple; + } + + private Expression rewriteRecordLiteral(RecordLiteral record) { + Map rewrittenMembers = new LinkedHashMap<>(); + boolean changed = false; + + for (Map.Entry entry : record.members().entrySet()) { + Literal original = entry.getValue(); + Literal rewritten = (Literal) rewrite(original); + rewrittenMembers.put(entry.getKey(), rewritten); + if (rewritten != original) { + changed = true; + } + } + + return changed ? Literal.recordLiteral(rewrittenMembers) : record; + } + + private Expression rewriteLibraryFunction(LibraryFunction fn) { + List rewrittenArgs = new ArrayList<>(); + boolean changed = false; + + for (Expression arg : fn.getArguments()) { + Expression rewritten = rewrite(arg); + rewrittenArgs.add(rewritten); + if (rewritten != arg) { + changed = true; + } + } + + if (!changed) { + return fn; + } + + FunctionNode node = FunctionNode.builder() + .name(Node.from(fn.getName())) + .arguments(rewrittenArgs) + .build(); + return fn.getFunctionDefinition().createFunction(node); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java new file mode 100644 index 00000000000..7d477a72b51 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java @@ -0,0 +1,220 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Analyzes variables in an endpoint rule set, collecting bindings, reference counts, + * and expression mappings needed for SSA transformation. + */ +final class VariableAnalysis { + private final Set inputParams; + private final Map> bindings; + private final Map referenceCounts; + private final Map> expressionMappings; + + private VariableAnalysis( + Set inputParams, + Map> bindings, + Map referenceCounts, + Map> expressionMappings + ) { + this.inputParams = inputParams; + this.bindings = bindings; + this.referenceCounts = referenceCounts; + this.expressionMappings = expressionMappings; + } + + static VariableAnalysis analyze(EndpointRuleSet ruleSet) { + Set inputParameters = extractInputParameters(ruleSet); + + AnalysisVisitor visitor = new AnalysisVisitor(inputParameters); + for (Rule rule : ruleSet.getRules()) { + visitor.visitRule(rule); + } + + return new VariableAnalysis( + inputParameters, + visitor.bindings, + visitor.referenceCounts, + createExpressionMappings(visitor.bindings)); + } + + Set getInputParams() { + return inputParams; + } + + Map> getExpressionMappings() { + return expressionMappings; + } + + int getReferenceCount(String variableName) { + return referenceCounts.getOrDefault(variableName, 0); + } + + boolean isReferencedOnce(String variableName) { + return getReferenceCount(variableName) == 1; + } + + boolean hasSingleBinding(String variableName) { + Set expressions = bindings.get(variableName); + return expressions != null && expressions.size() == 1; + } + + boolean hasMultipleBindings(String variableName) { + Set expressions = bindings.get(variableName); + return expressions != null && expressions.size() > 1; + } + + boolean isSafeToInline(String variableName) { + return hasSingleBinding(variableName) && isReferencedOnce(variableName); + } + + private static Set extractInputParameters(EndpointRuleSet ruleSet) { + Set inputParameters = new HashSet<>(); + for (Parameter param : ruleSet.getParameters()) { + inputParameters.add(param.getName().toString()); + } + return inputParameters; + } + + private static Map> createExpressionMappings( + Map> bindings + ) { + Map> result = new HashMap<>(); + for (Map.Entry> entry : bindings.entrySet()) { + String varName = entry.getKey(); + Set expressions = entry.getValue(); + result.put(varName, createMappingForVariable(varName, expressions)); + } + return result; + } + + private static Map createMappingForVariable( + String varName, + Set expressions + ) { + Map mapping = new HashMap<>(); + + if (expressions.size() == 1) { + // Single binding: no SSA rename needed + String expression = expressions.iterator().next(); + mapping.put(expression, varName); + } else { + // Multiple bindings: use SSA naming convention + List sortedExpressions = new ArrayList<>(expressions); + sortedExpressions.sort(String::compareTo); + for (int i = 0; i < sortedExpressions.size(); i++) { + String expression = sortedExpressions.get(i); + String uniqueName = varName + "_ssa_" + (i + 1); + mapping.put(expression, uniqueName); + } + } + + return mapping; + } + + private static class AnalysisVisitor { + final Map> bindings = new HashMap<>(); + final Map referenceCounts = new HashMap<>(); + private final Set inputParams; + + AnalysisVisitor(Set inputParams) { + this.inputParams = inputParams; + } + + void visitRule(Rule rule) { + for (Condition condition : rule.getConditions()) { + if (condition.getResult().isPresent()) { + String varName = condition.getResult().get().toString(); + LibraryFunction fn = condition.getFunction(); + String expression = fn.toString(); + + bindings.computeIfAbsent(varName, k -> new HashSet<>()) + .add(expression); + } + + countReferences(condition.getFunction()); + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule nestedRule : treeRule.getRules()) { + visitRule(nestedRule); + } + } else if (rule instanceof EndpointRule) { + EndpointRule endpointRule = (EndpointRule) rule; + Endpoint endpoint = endpointRule.getEndpoint(); + countReferences(endpoint.getUrl()); + endpoint.getHeaders() + .values() + .stream() + .flatMap(List::stream) + .forEach(this::countReferences); + endpoint.getProperties() + .values() + .forEach(this::countReferences); + } else if (rule instanceof ErrorRule) { + countReferences(((ErrorRule) rule).getError()); + } + } + + private void countReferences(Expression expression) { + if (expression instanceof Reference) { + Reference ref = (Reference) expression; + String name = ref.getName().toString(); + referenceCounts.merge(name, 1, Integer::sum); + } else if (expression instanceof StringLiteral) { + StringLiteral str = (StringLiteral) expression; + Template template = str.value(); + if (!template.isStatic()) { + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + countReferences(dynamic.toExpression()); + } + } + } + } else if (expression instanceof LibraryFunction) { + LibraryFunction fn = (LibraryFunction) expression; + for (Expression arg : fn.getArguments()) { + countReferences(arg); + } + } else if (expression instanceof TupleLiteral) { + TupleLiteral tuple = (TupleLiteral) expression; + for (Literal member : tuple.members()) { + countReferences(member); + } + } else if (expression instanceof RecordLiteral) { + RecordLiteral record = (RecordLiteral) expression; + for (Literal value : record.members().values()) { + countReferences(value); + } + } + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java new file mode 100644 index 00000000000..b0846e786b3 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java @@ -0,0 +1,285 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Consolidates variable names for identical expressions and eliminates redundant bindings. + * + *

This transform identifies conditions that compute the same expression but assign + * the result to different variable names, and either consolidates them to use the same + * name or eliminates redundant bindings when the same expression is already bound in + * an ancestor scope. + */ +final class VariableConsolidationTransform { + private static final Logger LOGGER = Logger.getLogger(VariableConsolidationTransform.class.getName()); + + // Global map of canonical expressions to their first variable name seen + private final Map globalExpressionToVar = new HashMap<>(); + + // Maps old variable names to new canonical names for rewriting references + private final Map variableRenameMap = new HashMap<>(); + + // Tracks conditions to eliminate (by their path in the tree) + private final Set conditionsToEliminate = new HashSet<>(); + + // Tracks all variables defined at each scope level to check for conflicts + private final Map> scopeDefinedVars = new HashMap<>(); + + private int consolidatedCount = 0; + private int eliminatedCount = 0; + private int skippedDueToShadowing = 0; + + public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + VariableConsolidationTransform transform = new VariableConsolidationTransform(); + return transform.consolidate(ruleSet); + } + + private EndpointRuleSet consolidate(EndpointRuleSet ruleSet) { + LOGGER.info("Starting variable consolidation transform"); + + for (int i = 0; i < ruleSet.getRules().size(); i++) { + collectDefinitions(ruleSet.getRules().get(i), "rule[" + i + "]"); + } + + for (int i = 0; i < ruleSet.getRules().size(); i++) { + discoverBindingsInRule(ruleSet.getRules().get(i), "rule[" + i + "]", new HashMap<>(), new HashSet<>()); + } + + List transformedRules = new ArrayList<>(); + for (int i = 0; i < ruleSet.getRules().size(); i++) { + transformedRules.add(transformRule(ruleSet.getRules().get(i), "rule[" + i + "]")); + } + + LOGGER.info(String.format("Variable consolidation: %d consolidated, %d eliminated, %d skipped due to shadowing", + consolidatedCount, + eliminatedCount, + skippedDueToShadowing)); + + return EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(transformedRules) + .version(ruleSet.getVersion()) + .build(); + } + + private void collectDefinitions(Rule rule, String path) { + Set definedVars = new HashSet<>(); + + // Collect all variables defined at this scope level + for (Condition condition : rule.getConditions()) { + if (condition.getResult().isPresent()) { + definedVars.add(condition.getResult().get().toString()); + } + } + + scopeDefinedVars.put(path, definedVars); + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (int i = 0; i < treeRule.getRules().size(); i++) { + collectDefinitions(treeRule.getRules().get(i), path + "/tree/rule[" + i + "]"); + } + } + } + + private void discoverBindingsInRule( + Rule rule, + String path, + Map parentBindings, + Set ancestorVars + ) { + // Track bindings at current scope (inherits parent bindings) + Map currentBindings = new HashMap<>(parentBindings); + // Track all variables visible from ancestors (for shadowing check) + Set visibleAncestorVars = new HashSet<>(ancestorVars); + + for (int i = 0; i < rule.getConditions().size(); i++) { + Condition condition = rule.getConditions().get(i); + String condPath = path + "/cond[" + i + "]"; + + if (condition.getResult().isPresent()) { + String varName = condition.getResult().get().toString(); + LibraryFunction fn = condition.getFunction(); + String canonical = fn.canonicalize().toString(); + + // Check if this expression is already bound in parent scope + String parentVar = parentBindings.get(canonical); + if (parentVar != null) { + // Found duplicate in parent, eliminate this binding + variableRenameMap.put(varName, parentVar); + conditionsToEliminate.add(condPath); + eliminatedCount++; + LOGGER.info(String.format("Eliminating redundant binding at %s: '%s' -> '%s' for: %s", + condPath, + varName, + parentVar, + canonical)); + } else { + // Not bound in parent, add to current scope + currentBindings.put(canonical, varName); + visibleAncestorVars.add(varName); + + // Check for global consolidation opportunity + String globalVar = globalExpressionToVar.get(canonical); + if (globalVar != null && !globalVar.equals(varName)) { + // Same expression elsewhere with different name + // Check if consolidation would cause shadowing + if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { + variableRenameMap.put(varName, globalVar); + consolidatedCount++; + LOGGER.info(String.format("Consolidating '%s' -> '%s' for: %s", + varName, + globalVar, + canonical)); + } else { + skippedDueToShadowing++; + LOGGER.fine(String.format("Cannot consolidate '%s' -> '%s' (would shadow) for: %s", + varName, + globalVar, + canonical)); + } + } else if (globalVar == null) { + // First time seeing this expression globally + globalExpressionToVar.put(canonical, varName); + } + } + } + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (int i = 0; i < treeRule.getRules().size(); i++) { + discoverBindingsInRule( + treeRule.getRules().get(i), + path + "/tree/rule[" + i + "]", + currentBindings, + visibleAncestorVars); + } + } + } + + private boolean wouldCauseShadowing(String varName, String currentPath, Set ancestorVars) { + // Check if using this variable name would shadow an ancestor variable + if (ancestorVars.contains(varName)) { + return true; + } + + // Check if any child scope already defines this variable + // (which would be shadowed if we use it here) + for (Map.Entry> entry : scopeDefinedVars.entrySet()) { + String scopePath = entry.getKey(); + Set scopeVars = entry.getValue(); + // Check if this scope is a descendant of current path + if (scopePath.startsWith(currentPath + "/") && scopeVars.contains(varName)) { + return true; + } + } + + return false; + } + + private Rule transformRule(Rule rule, String path) { + List transformedConditions = new ArrayList<>(); + + for (int i = 0; i < rule.getConditions().size(); i++) { + String condPath = path + "/cond[" + i + "]"; + + if (conditionsToEliminate.contains(condPath)) { + // Skip this condition entirely since it's redundant + continue; + } + + Condition condition = rule.getConditions().get(i); + transformedConditions.add(transformCondition(condition)); + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + return TreeRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .treeRule(TreeRewriter.transformNestedRules(treeRule, path, this::transformRule)); + + } else if (rule instanceof EndpointRule) { + EndpointRule endpointRule = (EndpointRule) rule; + TreeRewriter rewriter = createRewriter(); + + return EndpointRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .endpoint(rewriter.rewriteEndpoint(endpointRule.getEndpoint())); + + } else if (rule instanceof ErrorRule) { + ErrorRule errorRule = (ErrorRule) rule; + TreeRewriter rewriter = createRewriter(); + + return ErrorRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .error(rewriter.rewrite(errorRule.getError())); + } + + return rule.withConditions(transformedConditions); + } + + private Condition transformCondition(Condition condition) { + // Rewrite any references in the function + TreeRewriter rewriter = createRewriter(); + LibraryFunction fn = condition.getFunction(); + LibraryFunction rewrittenFn = (LibraryFunction) rewriter.rewrite(fn); + + // If this condition assigns to a variable that should be renamed, + // use the canonical name instead + if (condition.getResult().isPresent()) { + String varName = condition.getResult().get().toString(); + String canonicalName = variableRenameMap.get(varName); + + if (canonicalName != null) { + // This variable is being consolidated, use the canonical name + return Condition.builder() + .fn(rewrittenFn) + .result(Identifier.of(canonicalName)) + .build(); + } + } + + // No consolidation needed, but may still need reference rewriting + if (rewrittenFn != fn) { + return condition.toBuilder().fn(rewrittenFn).build(); + } + + return condition; + } + + private TreeRewriter createRewriter() { + if (variableRenameMap.isEmpty()) { + return TreeRewriter.IDENTITY; + } + + Map replacements = new HashMap<>(); + for (Map.Entry entry : variableRenameMap.entrySet()) { + replacements.put(entry.getKey(), Expression.getReference(Identifier.of(entry.getValue()))); + } + + return TreeRewriter.forReplacements(replacements); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java new file mode 100644 index 00000000000..9cd9d627c36 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -0,0 +1,369 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.traits; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Set; +import java.util.function.Function; +import software.amazon.smithy.model.node.ArrayNode; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.traits.AbstractTrait; +import software.amazon.smithy.model.traits.AbstractTraitBuilder; +import software.amazon.smithy.model.traits.Trait; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.logic.bdd.BddCompiler; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.SetUtils; +import software.amazon.smithy.utils.SmithyBuilder; +import software.amazon.smithy.utils.ToSmithyBuilder; + +/** + * Trait containing a precompiled BDD with full context for endpoint resolution. + */ +public final class EndpointBddTrait extends AbstractTrait implements ToSmithyBuilder { + public static final ShapeId ID = ShapeId.from("smithy.rules#endpointBdd"); + + private static final RulesVersion MIN_VERSION = RulesVersion.V1_1; + private static final Set ALLOWED_PROPERTIES = SetUtils.of( + "version", + "parameters", + "conditions", + "results", + "root", + "nodes", + "nodeCount"); + + private final RulesVersion version; + private final Parameters parameters; + private final List conditions; + private final List results; + private final Bdd bdd; + + private EndpointBddTrait(Builder builder) { + super(ID, builder.getSourceLocation()); + this.version = SmithyBuilder.requiredState("version", builder.version); + this.parameters = SmithyBuilder.requiredState("parameters", builder.parameters); + this.conditions = SmithyBuilder.requiredState("conditions", builder.conditions); + this.results = SmithyBuilder.requiredState("results", builder.results); + this.bdd = SmithyBuilder.requiredState("bdd", builder.bdd); + + if (version.compareTo(MIN_VERSION) < 0) { + throw new IllegalArgumentException("Rules engine version for endpointBdd trait must be >= " + MIN_VERSION); + } + } + + /** + * Creates a BddTrait from a control flow graph. + * + * @param cfg the control flow graph to compile + * @return the BddTrait containing the compiled BDD and all context + */ + public static EndpointBddTrait from(Cfg cfg) { + BddCompiler compiler = new BddCompiler(cfg); + Bdd bdd = compiler.compile(); + + if (compiler.getOrderedConditions().size() != bdd.getConditionCount()) { + throw new IllegalStateException("Mismatch between BDD var count and orderedConditions size"); + } + + // Automatically convert 1.0 versions of the decision tree to 1.1 for the minimum version of the BDD trait. + RulesVersion version = cfg.getVersion(); + if (version.equals(RulesVersion.V1_0)) { + version = RulesVersion.V1_1; + } + + return builder() + .version(version) + .parameters(cfg.getParameters()) + .conditions(compiler.getOrderedConditions()) + .results(compiler.getIndexedResults()) + .bdd(bdd) + .build(); + } + + /** + * Gets the parameters for the endpoint rules. + * + * @return the parameters + */ + public Parameters getParameters() { + return parameters; + } + + /** + * Gets the ordered list of conditions. + * + * @return the conditions in evaluation order + */ + public List getConditions() { + return conditions; + } + + /** + * Gets the ordered list of results. + * + * @return the results (index 0 is always NoMatchRule) + */ + public List getResults() { + return results; + } + + /** + * Gets the BDD structure. + * + * @return the BDD + */ + public Bdd getBdd() { + return bdd; + } + + /** + * Get the endpoint ruleset version. + * + * @return the rules engine version + */ + public RulesVersion getVersion() { + return version; + } + + /** + * Transform this BDD using the given function and return the updated BddTrait. + * + * @param transformer Transformer used to modify the trait. + * @return the updated trait. + */ + public EndpointBddTrait transform(Function transformer) { + return transformer.apply(this); + } + + @Override + protected Node createNode() { + ObjectNode.Builder builder = ObjectNode.builder(); + builder.withMember("version", version.toString()); + builder.withMember("parameters", parameters.toNode()); + + ArrayNode.Builder conditionBuilder = ArrayNode.builder(); + for (Condition c : conditions) { + conditionBuilder.withValue(c); + } + builder.withMember("conditions", conditionBuilder.build()); + + // Results (skip NoMatchRule at index 0 for serialization) + ArrayNode.Builder resultBuilder = ArrayNode.builder(); + if (!results.isEmpty() && !(results.get(0) instanceof NoMatchRule)) { + throw new IllegalStateException("BDD must always have a NoMatchRule as the first result"); + } + for (int i = 1; i < results.size(); i++) { + Rule result = results.get(i); + if (result instanceof NoMatchRule) { + throw new IllegalStateException("NoMatch rules can only appear at rule index 0. Found at index " + i); + } else if (result == null) { + throw new IllegalStateException("BDD result is null at index " + i); + } + resultBuilder.withValue(result); + } + builder.withMember("results", resultBuilder.build()); + + builder.withMember("root", bdd.getRootRef()); + builder.withMember("nodeCount", bdd.getNodeCount()); + builder.withMember("nodes", encodeNodes(bdd)); + + return builder.build(); + } + + /** + * Creates a BddTrait from a Node representation. + * + * @param node the node to parse + * @return the BddTrait + */ + public static EndpointBddTrait fromNode(Node node) { + ObjectNode obj = node.expectObjectNode(); + obj.warnIfAdditionalProperties(ALLOWED_PROPERTIES); + RulesVersion version = RulesVersion.of(obj.expectStringMember("version").getValue()); + Parameters params = Parameters.fromNode(obj.expectObjectMember("parameters")); + List conditions = obj.expectArrayMember("conditions").getElementsAs(Condition::fromNode); + + List serializedResults = obj.expectArrayMember("results").getElementsAs(Rule::fromNode); + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 + results.addAll(serializedResults); + + String nodesBase64 = obj.expectStringMember("nodes").getValue(); + int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); + int rootRef = obj.expectNumberMember("root").getValue().intValue(); + + Bdd bdd = decodeBdd(nodesBase64, nodeCount, rootRef, conditions.size(), results.size()); + + EndpointBddTrait trait = builder() + .version(version) + .sourceLocation(node) + .parameters(params) + .conditions(conditions) + .results(results) + .bdd(bdd) + .build(); + trait.setNodeCache(node); + return trait; + } + + private static String encodeNodes(Bdd bdd) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos)) { + bdd.getNodes((varIdx, high, low) -> { + try { + dos.writeInt(varIdx); + dos.writeInt(high); + dos.writeInt(low); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + dos.flush(); + return Base64.getEncoder().encodeToString(baos.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Failed to encode BDD nodes", e); + } catch (UncheckedIOException e) { + throw new RuntimeException("Failed to encode BDD nodes", e.getCause()); + } + } + + private static Bdd decodeBdd(String base64, int nodeCount, int rootRef, int conditionCount, int resultCount) { + byte[] data = Base64.getDecoder().decode(base64); + if (data.length != nodeCount * 12) { + throw new IllegalArgumentException("Expected " + (nodeCount * 12) + " bytes for " + nodeCount + + " nodes, but got " + data.length); + } + + int[] nodes = new int[nodeCount * 3]; + ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.BIG_ENDIAN); + for (int i = 0; i < nodes.length; i++) { + nodes[i] = buffer.getInt(); + } + + return new Bdd(rootRef, conditionCount, resultCount, nodeCount, nodes); + } + + /** + * Creates a new builder for BddTrait. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public Builder toBuilder() { + return builder() + .version(version) + .sourceLocation(getSourceLocation()) + .parameters(parameters) + .conditions(conditions) + .results(results) + .bdd(bdd); + } + + /** + * Builder for BddTrait. + */ + public static final class Builder extends AbstractTraitBuilder { + private RulesVersion version = RulesVersion.V1_1; + private Parameters parameters; + private List conditions; + private List results; + private Bdd bdd; + + private Builder() {} + + /** + * Sets the rules engine version. + * + * @param version Version to set (e.g., 1.1). + * @return this builder + */ + public Builder version(RulesVersion version) { + this.version = version; + return this; + } + + /** + * Sets the parameters. + * + * @param parameters the parameters + * @return this builder + */ + public Builder parameters(Parameters parameters) { + this.parameters = parameters; + return this; + } + + /** + * Sets the conditions. + * + * @param conditions the conditions in evaluation order + * @return this builder + */ + public Builder conditions(List conditions) { + this.conditions = conditions; + return this; + } + + /** + * Sets the results. + * + * @param results the results (must have NoMatchRule at index 0) + * @return this builder + */ + public Builder results(List results) { + this.results = results; + return this; + } + + /** + * Sets the BDD structure. + * + * @param bdd the BDD + * @return this builder + */ + public Builder bdd(Bdd bdd) { + this.bdd = bdd; + return this; + } + + @Override + public EndpointBddTrait build() { + return new EndpointBddTrait(this); + } + } + + public static final class Provider extends AbstractTrait.Provider { + public Provider() { + super(ID); + } + + @Override + public Trait createTrait(ShapeId target, Node value) { + EndpointBddTrait trait = EndpointBddTrait.fromNode(value); + trait.setNodeCache(value); + return trait; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java new file mode 100644 index 00000000000..eb5c4ed76d7 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java @@ -0,0 +1,120 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.validators; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.validation.AbstractValidator; +import software.amazon.smithy.model.validation.ValidationEvent; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; + +public final class BddTraitValidator extends AbstractValidator { + @Override + public List validate(Model model) { + if (!model.isTraitApplied(EndpointBddTrait.class)) { + return Collections.emptyList(); + } + + List events = new ArrayList<>(); + for (ServiceShape service : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateService(events, service, service.expectTrait(EndpointBddTrait.class)); + } + + return events; + } + + private void validateService(List events, ServiceShape service, EndpointBddTrait trait) { + Bdd bdd = trait.getBdd(); + + // Validate root reference + int rootRef = bdd.getRootRef(); + if (Bdd.isComplemented(rootRef) && rootRef != -1) { + events.add(error(service, trait, "Root reference cannot be complemented: " + rootRef)); + } + validateReference(events, service, trait, "Root", rootRef, bdd, trait); + + // Validate that condition and result counts match what's in the trait + if (bdd.getConditionCount() != trait.getConditions().size()) { + events.add(error(service, + trait, + String.format("BDD condition count (%d) doesn't match trait conditions (%d)", + bdd.getConditionCount(), + trait.getConditions().size()))); + } + + if (bdd.getResultCount() != trait.getResults().size()) { + events.add(error(service, + trait, + String.format("BDD result count (%d) doesn't match trait results (%d)", + bdd.getResultCount(), + trait.getResults().size()))); + } + + // Validate nodes + int nodeCount = bdd.getNodeCount(); + + for (int i = 1; i < nodeCount; i++) { + int varIdx = bdd.getVariable(i); + int highRef = bdd.getHigh(i); + int lowRef = bdd.getLow(i); + + if (varIdx < 0 || varIdx >= bdd.getConditionCount()) { + events.add(error(service, + trait, + String.format( + "Node %d has invalid variable index %d (condition count: %d)", + i, + varIdx, + bdd.getConditionCount()))); + } + + validateReference(events, service, trait, String.format("Node %d high", i), highRef, bdd, trait); + validateReference(events, service, trait, String.format("Node %d low", i), lowRef, bdd, trait); + } + } + + private void validateReference( + List events, + ServiceShape service, + EndpointBddTrait trait, + String context, + int ref, + Bdd bdd, + EndpointBddTrait bddTrait + ) { + if (ref == 0) { + events.add(error(service, trait, String.format("%s has invalid reference: 0", context))); + } else if (Bdd.isNodeReference(ref)) { + int nodeIndex = Math.abs(ref) - 1; + int nodeCount = bdd.getNodeCount(); + if (nodeIndex >= nodeCount) { + events.add(error(service, + trait, + String.format( + "%s reference %d points to non-existent node %d (node count: %d)", + context, + ref, + nodeIndex, + nodeCount))); + } + } else if (Bdd.isResultReference(ref)) { + int resultIndex = ref - Bdd.RESULT_OFFSET; + if (resultIndex >= bddTrait.getResults().size()) { + events.add(error(service, + trait, + String.format( + "%s reference %d points to non-existent result %d (result count: %d)", + context, + ref, + resultIndex, + bddTrait.getResults().size()))); + } + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointTestsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java similarity index 64% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointTestsTraitValidator.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java index bb8fd6f37f5..c2d3f0de3b2 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointTestsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java @@ -2,7 +2,7 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rulesengine.traits; +package software.amazon.smithy.rulesengine.validators; import java.util.ArrayList; import java.util.HashMap; @@ -18,8 +18,13 @@ import software.amazon.smithy.model.validation.NodeValidationVisitor; import software.amazon.smithy.model.validation.Severity; import software.amazon.smithy.model.validation.ValidationEvent; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -39,61 +44,81 @@ public List validate(Model model) { operationNameMap.put(operationShape.getId().getName(), operationShape); } - // Precompute the built-ins and their default states, as this will - // be used frequently in downstream validation. - List builtInParamsWithDefaults = new ArrayList<>(); - List builtInParamsWithoutDefaults = new ArrayList<>(); - EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); - for (Parameter parameter : ruleSet.getParameters()) { - if (parameter.isBuiltIn()) { - if (parameter.getDefault().isPresent()) { - builtInParamsWithDefaults.add(parameter); - } else { - builtInParamsWithoutDefaults.add(parameter); - } + serviceShape.getTrait(EndpointRuleSetTrait.class).ifPresent(trait -> { + validateEndpointRuleSet(events, + model, + serviceShape, + trait.getEndpointRuleSet().getParameters(), + operationNameMap); + }); + + serviceShape.getTrait(EndpointBddTrait.class).ifPresent(trait -> { + validateEndpointRuleSet(events, model, serviceShape, trait.getParameters(), operationNameMap); + }); + } + + return events; + } + + private void validateEndpointRuleSet( + List events, + Model model, + ServiceShape serviceShape, + Parameters parameters, + Map operationNameMap + ) { + // Precompute the built-ins and their default states, as this will + // be used frequently in downstream validation. + List builtInParamsWithDefaults = new ArrayList<>(); + List builtInParamsWithoutDefaults = new ArrayList<>(); + + for (Parameter parameter : parameters) { + if (parameter.isBuiltIn()) { + if (parameter.getDefault().isPresent()) { + builtInParamsWithDefaults.add(parameter); + } else { + builtInParamsWithoutDefaults.add(parameter); } } + } - for (EndpointTestCase testCase : serviceShape.expectTrait(EndpointTestsTrait.class).getTestCases()) { - // If values for built-in parameters don't match the default, they MUST - // be specified in the operation inputs. Precompute the ones that don't match - // and capture their value. - Map builtInParamsWithNonDefaultValues = - getBuiltInParamsWithNonDefaultValues(builtInParamsWithDefaults, testCase); - - for (EndpointTestOperationInput testOperationInput : testCase.getOperationInputs()) { - String operationName = testOperationInput.getOperationName(); - - // It's possible for an operation defined to not be in the service closure. - if (!operationNameMap.containsKey(operationName)) { - events.add(error(serviceShape, - testOperationInput, - String.format("Test case operation `%s` does not exist in service `%s`", - operationName, - serviceShape.getId()))); - continue; - } - - // Still emit events if the operation exists, but was just not bound. - validateConfiguredBuiltInValues(serviceShape, - builtInParamsWithNonDefaultValues, - testOperationInput, - events); - validateBuiltInsWithoutDefaultsHaveValues(serviceShape, - builtInParamsWithoutDefaults, - testCase, - testOperationInput, - events); + for (EndpointTestCase testCase : serviceShape.expectTrait(EndpointTestsTrait.class).getTestCases()) { + // If values for built-in parameters don't match the default, they MUST + // be specified in the operation inputs. Precompute the ones that don't match + // and capture their value. + Map builtInParamsWithNonDefaultValues = + getBuiltInParamsWithNonDefaultValues(builtInParamsWithDefaults, testCase); + + for (EndpointTestOperationInput testOperationInput : testCase.getOperationInputs()) { + String operationName = testOperationInput.getOperationName(); - StructureShape inputShape = model.expectShape( - operationNameMap.get(operationName).getInputShape(), - StructureShape.class); - validateOperationInput(model, serviceShape, inputShape, testCase, testOperationInput, events); + // It's possible for an operation defined to not be in the service closure. + if (!operationNameMap.containsKey(operationName)) { + events.add(error(serviceShape, + testOperationInput, + String.format("Test case operation `%s` does not exist in service `%s`", + operationName, + serviceShape.getId()))); + continue; } + + // Still emit events if the operation exists, but was just not bound. + validateConfiguredBuiltInValues(serviceShape, + builtInParamsWithNonDefaultValues, + testOperationInput, + events); + validateBuiltInsWithoutDefaultsHaveValues(serviceShape, + builtInParamsWithoutDefaults, + testCase, + testOperationInput, + events); + + StructureShape inputShape = model.expectShape( + operationNameMap.get(operationName).getInputShape(), + StructureShape.class); + validateOperationInput(model, serviceShape, inputShape, testCase, testOperationInput, events); } } - - return events; } private Map getBuiltInParamsWithNonDefaultValues( diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/OperationContextParamsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/OperationContextParamsTraitValidator.java similarity index 97% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/OperationContextParamsTraitValidator.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/OperationContextParamsTraitValidator.java index e06df8df172..59f1524591e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/OperationContextParamsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/OperationContextParamsTraitValidator.java @@ -2,7 +2,7 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rulesengine.traits; +package software.amazon.smithy.rulesengine.validators; import java.util.ArrayList; import java.util.Collections; @@ -36,6 +36,9 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.analysis.OperationContextParamsChecker; +import software.amazon.smithy.rulesengine.traits.ContextIndex; +import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition; +import software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait; import software.amazon.smithy.utils.ListUtils; import software.amazon.smithy.utils.SmithyUnstableApi; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index 1c678f82bd2..c79ab5fbeb8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -8,11 +8,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.function.BiFunction; -import java.util.stream.Collectors; -import java.util.stream.Stream; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.ServiceShape; @@ -20,11 +16,13 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; -import software.amazon.smithy.rulesengine.language.TraversingVisitor; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; -import software.amazon.smithy.utils.ListUtils; /** * Validator which verifies an endpoint with an authSchemes property conforms to a strict schema. @@ -35,111 +33,141 @@ public final class RuleSetAuthSchemesValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - Validator validator = new Validator(serviceShape); - events.addAll(validator.visitRuleset( - serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet()) - .collect(Collectors.toList())); + for (ServiceShape serviceShape : model.getServiceShapes()) { + visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(EndpointBddTrait.class).orElse(null)); } return events; } - private class Validator extends TraversingVisitor { - private final ServiceShape serviceShape; - - Validator(ServiceShape serviceShape) { - this.serviceShape = serviceShape; + private void visitRuleset(List events, ServiceShape serviceShape, EndpointRuleSetTrait trait) { + if (trait != null) { + for (Rule rule : trait.getEndpointRuleSet().getRules()) { + traverse(events, serviceShape, rule); + } } + } - @Override - public Stream visitEndpoint(Endpoint endpoint) { - List events = new ArrayList<>(); - - Literal authSchemes = endpoint.getProperties().get(Identifier.of("authSchemes")); - if (authSchemes != null) { - BiFunction emitter = getEventEmitter(); - Optional> authSchemeList = authSchemes.asTupleLiteral(); - if (!authSchemeList.isPresent()) { - return Stream.of(emitter.apply(authSchemes, - String.format("Expected `authSchemes` to be a list, found: `%s`", authSchemes))); + private void visitBdd(List events, ServiceShape serviceShape, EndpointBddTrait trait) { + if (trait != null) { + for (Rule result : trait.getResults()) { + if (result instanceof EndpointRule) { + visitEndpoint(events, serviceShape, (EndpointRule) result); } + } + } + } - Set authSchemeNames = new HashSet<>(); - Set duplicateAuthSchemeNames = new HashSet<>(); - for (Literal authSchemeEntry : authSchemeList.get()) { - Optional> authSchemeMap = authSchemeEntry.asRecordLiteral(); - if (authSchemeMap.isPresent()) { - // Validate the name property so that we can also check that they're unique. - Map authScheme = authSchemeMap.get(); - Optional event = validateAuthSchemeName(authScheme, authSchemeEntry); - if (event.isPresent()) { - events.add(event.get()); - continue; - } - String schemeName = authScheme.get(NAME).asStringLiteral().get().expectLiteral(); - if (!authSchemeNames.add(schemeName)) { - duplicateAuthSchemeNames.add(schemeName); - } - - events.addAll(validateAuthScheme(schemeName, authScheme, authSchemeEntry)); - } else { - events.add(emitter.apply(authSchemes, - String.format("Expected `authSchemes` to be a list of objects, but found: `%s`", - authSchemeEntry))); - } + private void traverse(List events, ServiceShape service, Rule rule) { + if (rule instanceof EndpointRule) { + visitEndpoint(events, service, (EndpointRule) rule); + } else if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule child : treeRule.getRules()) { + traverse(events, service, child); + } + } + } + + private void visitEndpoint(List events, ServiceShape service, EndpointRule endpointRule) { + Endpoint endpoint = endpointRule.getEndpoint(); + Literal authSchemes = endpoint.getProperties().get(Identifier.of("authSchemes")); + + if (authSchemes != null) { + List authSchemeList = authSchemes.asTupleLiteral().orElse(null); + if (authSchemeList == null) { + events.add(error(service, + authSchemes, + String.format( + "Expected `authSchemes` to be a list, found: `%s`", + authSchemes))); + return; + } + + Set authSchemeNames = new HashSet<>(); + Set duplicateAuthSchemeNames = new HashSet<>(); + for (Literal authSchemeEntry : authSchemeList) { + Map authSchemeMap = authSchemeEntry.asRecordLiteral().orElse(null); + if (authSchemeMap == null) { + events.add(error(service, + authSchemes, + String.format( + "Expected `authSchemes` to be a list of objects, but found: `%s`", + authSchemeEntry))); + continue; } - // Emit events for each duplicated auth scheme name. - for (String duplicateAuthSchemeName : duplicateAuthSchemeNames) { - events.add(emitter.apply(authSchemes, - String.format("Found duplicate `name` of `%s` in the " - + "`authSchemes` list", duplicateAuthSchemeName))); + String schemeName = validateAuthSchemeName(events, service, authSchemeMap, authSchemeEntry); + if (schemeName != null) { + if (!authSchemeNames.add(schemeName)) { + duplicateAuthSchemeNames.add(schemeName); + } + validateAuthScheme(events, service, schemeName, authSchemeMap, authSchemeEntry); } } - return events.stream(); - } - - private Optional validateAuthSchemeName( - Map authScheme, - FromSourceLocation sourceLocation - ) { - if (!authScheme.containsKey(NAME) || !authScheme.get(NAME).asStringLiteral().isPresent()) { - return Optional.of(error(serviceShape, - sourceLocation, - String.format("Expected `authSchemes` to have a `name` key with a string value but it did not: " - + "`%s`", authScheme))); + // Emit events for each duplicated auth scheme name. + for (String duplicateAuthSchemeName : duplicateAuthSchemeNames) { + events.add(error(service, + authSchemes, + String.format( + "Found duplicate `name` of `%s` in the `authSchemes` list", + duplicateAuthSchemeName))); } - return Optional.empty(); } + } - private List validateAuthScheme( - String schemeName, - Map authScheme, - FromSourceLocation sourceLocation - ) { - List events = new ArrayList<>(); + private String validateAuthSchemeName( + List events, + ServiceShape service, + Map authScheme, + FromSourceLocation sourceLocation + ) { + Literal nameLiteral = authScheme.get(NAME); + if (nameLiteral == null) { + events.add(error(service, + sourceLocation, + String.format( + "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", + authScheme))); + return null; + } - BiFunction emitter = getEventEmitter(); + String name = nameLiteral.asStringLiteral().map(s -> s.expectLiteral()).orElse(null); + if (name == null) { + events.add(error(service, + sourceLocation, + String.format( + "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", + authScheme))); + return null; + } - boolean validatedAuth = false; - for (AuthSchemeValidator authSchemeValidator : EndpointRuleSet.getAuthSchemeValidators()) { - if (authSchemeValidator.test(schemeName)) { - events.addAll(authSchemeValidator.validateScheme(authScheme, sourceLocation, emitter)); - validatedAuth = true; - } - } + return name; + } - if (validatedAuth) { - return events; + private void validateAuthScheme( + List events, + ServiceShape service, + String schemeName, + Map authScheme, + FromSourceLocation sourceLocation + ) { + boolean validatedAuth = false; + for (AuthSchemeValidator authSchemeValidator : EndpointRuleSet.getAuthSchemeValidators()) { + if (authSchemeValidator.test(schemeName)) { + events.addAll(authSchemeValidator.validateScheme(authScheme, + sourceLocation, + (location, message) -> error(service, location, message))); + validatedAuth = true; } - return ListUtils.of(warning(serviceShape, - String.format("Did not find a validator for the `%s` " - + "auth scheme", schemeName))); } - private BiFunction getEventEmitter() { - return (sourceLocation, message) -> error(serviceShape, sourceLocation, message); + if (!validatedAuth) { + events.add(warning(service, + String.format( + "Did not find a validator for the `%s` auth scheme", + schemeName))); } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java index 430412c1d1f..9a9b7ad7713 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Optional; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.node.StringNode; @@ -16,6 +15,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; @@ -28,67 +28,63 @@ public final class RuleSetBuiltInValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(validateRuleSetBuiltIns(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class) - .getEndpointRuleSet())); + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateParams(events, s, s.expectTrait(EndpointBddTrait.class).getParameters()); + } + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { + validateParams(events, s, s.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet().getParameters()); } - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointTestsTrait.class)) { - events.addAll(validateTestTraitBuiltIns(serviceShape, serviceShape.expectTrait(EndpointTestsTrait.class))); + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointTestsTrait.class)) { + validateTestBuiltIns(events, s, s.expectTrait(EndpointTestsTrait.class)); } + return events; } - private List validateRuleSetBuiltIns(ServiceShape serviceShape, EndpointRuleSet ruleSet) { - List events = new ArrayList<>(); - for (Parameter parameter : ruleSet.getParameters()) { + private void validateParams(List events, ServiceShape service, Iterable params) { + for (Parameter parameter : params) { if (parameter.isBuiltIn()) { - validateBuiltIn(serviceShape, parameter.getBuiltIn().get(), parameter, "RuleSet") - .ifPresent(events::add); + validateBuiltIn(events, service, parameter.getBuiltIn().get(), parameter, "RuleSet"); } } - return events; } - private List validateTestTraitBuiltIns(ServiceShape serviceShape, EndpointTestsTrait testSuite) { - List events = new ArrayList<>(); + private void validateTestBuiltIns(List events, ServiceShape service, EndpointTestsTrait suite) { int testIndex = 0; - for (EndpointTestCase testCase : testSuite.getTestCases()) { + for (EndpointTestCase testCase : suite.getTestCases()) { int inputIndex = 0; for (EndpointTestOperationInput operationInput : testCase.getOperationInputs()) { for (StringNode builtInNode : operationInput.getBuiltInParams().getMembers().keySet()) { - validateBuiltIn(serviceShape, + validateBuiltIn(events, + service, builtInNode.getValue(), operationInput, "TestCase", String.valueOf(testIndex), "Inputs", - String.valueOf(inputIndex)) - .ifPresent(events::add); + String.valueOf(inputIndex)); } inputIndex++; } testIndex++; } - return events; } - private Optional validateBuiltIn( - ServiceShape serviceShape, - String builtInName, + private void validateBuiltIn( + List events, + ServiceShape service, + String name, FromSourceLocation source, String... eventIdSuffixes ) { - if (!EndpointRuleSet.hasBuiltIn(builtInName)) { - return Optional.of(error(serviceShape, - source, - String.format( - "The `%s` built-in used is not registered, valid built-ins: %s", - builtInName, - EndpointRuleSet.getKeyString()), - String.join(".", Arrays.asList(eventIdSuffixes)))); + if (!EndpointRuleSet.hasBuiltIn(name)) { + String msg = String.format("The `%s` built-in used is not registered, valid built-ins: %s", + name, + EndpointRuleSet.getKeyString()); + events.add(error(service, source, msg, String.join(".", Arrays.asList(eventIdSuffixes)))); } - return Optional.empty(); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java index 5263f27e7aa..9a241219bbd 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java @@ -10,34 +10,43 @@ import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; /** - * Validator to ensure that all parameters have documentation. + * Validator to ensure that all parameters have documentation (in BDD and ruleset). */ public final class RuleSetParamMissingDocsValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(validateRuleSet(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class) - .getEndpointRuleSet())); + for (ServiceShape serviceShape : model.getServiceShapes()) { + visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(EndpointBddTrait.class).orElse(null)); } return events; } - public List validateRuleSet(ServiceShape serviceShape, EndpointRuleSet ruleSet) { - List events = new ArrayList<>(); - for (Parameter parameter : ruleSet.getParameters()) { + private void visitRuleset(List events, ServiceShape serviceShape, EndpointRuleSetTrait trait) { + if (trait != null) { + visitParams(events, serviceShape, trait.getEndpointRuleSet().getParameters()); + } + } + + private void visitBdd(List events, ServiceShape serviceShape, EndpointBddTrait trait) { + if (trait != null) { + visitParams(events, serviceShape, trait.getParameters()); + } + } + + public void visitParams(List events, ServiceShape serviceShape, Iterable parameters) { + for (Parameter parameter : parameters) { if (!parameter.getDocumentation().isPresent()) { events.add(warning(serviceShape, parameter, String.format("Parameter `%s` does not have documentation", parameter.getName()))); } } - return events; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java index 54d445bac93..709b84db30d 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java @@ -9,7 +9,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; @@ -23,22 +22,21 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.analysis.OperationContextParamsChecker; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.traits.ClientContextParamDefinition; import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait; import software.amazon.smithy.rulesengine.traits.ContextParamTrait; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; +import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition; import software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait; import software.amazon.smithy.rulesengine.traits.StaticContextParamDefinition; import software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait; import software.amazon.smithy.utils.ListUtils; -import software.amazon.smithy.utils.Pair; /** * Validator for rule-set parameters. @@ -47,45 +45,52 @@ public final class RuleSetParameterValidator extends AbstractValidator { @Override public List validate(Model model) { TopDownIndex topDownIndex = TopDownIndex.of(model); - List errors = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - // Pull all the parameters used in this service related to endpoints, validating that - // they are of matching types across the traits that can define them. - Pair, Map> errorsParamsPair = validateAndExtractParameters( - model, - serviceShape, - topDownIndex.getContainedOperations(serviceShape)); - errors.addAll(errorsParamsPair.getLeft()); - - // Make sure parameters align across Params <-> RuleSet transitions. - EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); - errors.addAll(validateParametersMatching(serviceShape, - ruleSet.getParameters(), - errorsParamsPair.getRight())); - // Check that tests declare required parameters, only defined parameters, etc. - if (serviceShape.hasTrait(EndpointTestsTrait.ID)) { - errors.addAll(validateTestsParameters( - serviceShape, - serviceShape.expectTrait(EndpointTestsTrait.class), - ruleSet)); + for (ServiceShape service : model.getServiceShapes()) { + EndpointRuleSetTrait epTrait = service.getTrait(EndpointRuleSetTrait.class).orElse(null); + EndpointBddTrait bddTrait = service.getTrait(EndpointBddTrait.class).orElse(null); + if (epTrait != null) { + validate(model, topDownIndex, service, errors, epTrait, epTrait.getEndpointRuleSet().getParameters()); + } + if (bddTrait != null) { + validate(model, topDownIndex, service, errors, bddTrait, bddTrait.getParameters()); } } return errors; } - private Pair, Map> validateAndExtractParameters( + private void validate( Model model, - ServiceShape serviceShape, + TopDownIndex topDownIndex, + ServiceShape service, + List errors, + FromSourceLocation sourceLocation, + Iterable parameters + ) { + // Pull all the parameters used in this service related to endpoints, validating that + // they are of matching types across the traits that can define them. + Set operations = topDownIndex.getContainedOperations(service); + Map modelParams = validateAndExtractParameters(errors, model, service, operations); + // Make sure parameters align across Params <-> RuleSet transitions. + validateParametersMatching(errors, service, sourceLocation, parameters, modelParams); + // Check that tests declare required parameters, only defined parameters, etc. + if (service.hasTrait(EndpointTestsTrait.ID)) { + validateTestsParameters(errors, service, service.expectTrait(EndpointTestsTrait.class), parameters); + } + } + + private Map validateAndExtractParameters( + List errors, + Model model, + ServiceShape service, Set containedOperations ) { - List errors = new ArrayList<>(); Map endpointParams = new HashMap<>(); - if (serviceShape.hasTrait(ClientContextParamsTrait.ID)) { - ClientContextParamsTrait trait = serviceShape.expectTrait(ClientContextParamsTrait.class); + if (service.hasTrait(ClientContextParamsTrait.ID)) { + ClientContextParamsTrait trait = service.expectTrait(ClientContextParamsTrait.class); for (Map.Entry entry : trait.getParameters().entrySet()) { endpointParams.put(entry.getKey(), Parameter.builder() @@ -120,10 +125,14 @@ private Pair, Map> validateAndExtractPa if (operationShape.hasTrait(OperationContextParamsTrait.ID)) { OperationContextParamsTrait trait = operationShape.expectTrait(OperationContextParamsTrait.class); - trait.getParameters().forEach((name, p) -> { - Optional maybeType = OperationContextParamsChecker - .inferParameterType(p, operationShape, model); - maybeType.ifPresent(parameterType -> { + for (Map.Entry entry : trait.getParameters().entrySet()) { + String name = entry.getKey(); + OperationContextParamDefinition p = entry.getValue(); + ParameterType parameterType = OperationContextParamsChecker + .inferParameterType(p, operationShape, model) + .orElse(null); + + if (parameterType != null) { if (endpointParams.containsKey(name) && endpointParams.get(name).getType() != parameterType) { errors.add(parameterError(operationShape, trait, @@ -136,8 +145,8 @@ private Pair, Map> validateAndExtractPa .type(parameterType) .build()); } - }); - }); + } + } } StructureShape input = model.expectShape(operationShape.getInputShape(), StructureShape.class); @@ -173,15 +182,16 @@ private Pair, Map> validateAndExtractPa } } - return Pair.of(errors, endpointParams); + return endpointParams; } - private List validateParametersMatching( + private void validateParametersMatching( + List errors, ServiceShape serviceShape, - Parameters ruleSetParams, + FromSourceLocation sourceLocation, + Iterable ruleSetParams, Map modelParams ) { - List errors = new ArrayList<>(); Set matchedParams = new HashSet<>(); for (Parameter parameter : ruleSetParams) { String name = parameter.getName().toString(); @@ -213,27 +223,25 @@ private List validateParametersMatching( for (Map.Entry entry : modelParams.entrySet()) { if (!matchedParams.contains(entry.getKey())) { errors.add(parameterError(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class), + sourceLocation, "RuleSet.UnmatchedName", String.format("Parameter `%s` exists in service model but not in ruleset, existing params: %s", entry.getKey(), matchedParams))); } } - - return errors; } - private List validateTestsParameters( + private void validateTestsParameters( + List errors, ServiceShape serviceShape, EndpointTestsTrait trait, - EndpointRuleSet ruleSet + Iterable parameters ) { - List errors = new ArrayList<>(); Set rulesetParamNames = new HashSet<>(); Map> testSuiteParams = extractTestSuiteParameters(trait.getTestCases()); - for (Parameter parameter : ruleSet.getParameters()) { + for (Parameter parameter : parameters) { String name = parameter.getName().toString(); rulesetParamNames.add(name); boolean testSuiteHasParam = testSuiteParams.containsKey(name); @@ -278,8 +286,6 @@ private List validateTestsParameters( String.format("Test parameter `%s` is not defined in ruleset", entry.getKey()))); } } - - return errors; } private Map> extractTestSuiteParameters(List testCases) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java index bfccecaf19c..33fd7c87b08 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java @@ -12,6 +12,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; @@ -23,22 +24,38 @@ public class RuleSetTestCaseValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - if (serviceShape.hasTrait(EndpointTestsTrait.ID)) { - EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); - EndpointTestsTrait testsTrait = serviceShape.expectTrait(EndpointTestsTrait.class); - - // Test/Rule evaluation throws RuntimeExceptions when evaluating, wrap these - // up into ValidationEvents for automatic validation. - for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { - try { - TestEvaluator.evaluate(ruleSet, endpointTestCase); - } catch (RuntimeException e) { - events.add(error(serviceShape, endpointTestCase, e.getMessage())); - } - } + for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointTestsTrait.class)) { + EndpointTestsTrait testsTrait = serviceShape.expectTrait(EndpointTestsTrait.class); + if (serviceShape.hasTrait(EndpointRuleSetTrait.class)) { + validate(serviceShape, testsTrait, events); + } else if (serviceShape.hasTrait(EndpointBddTrait.class)) { + validateBdd(serviceShape, testsTrait, events); } } return events; } + + // Test/Rule evaluation throws RuntimeExceptions when evaluating, wrap these + // up into ValidationEvents for automatic validation. + private void validate(ServiceShape serviceShape, EndpointTestsTrait testsTrait, List events) { + EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { + try { + TestEvaluator.evaluate(ruleSet, endpointTestCase); + } catch (RuntimeException e) { + events.add(error(serviceShape, endpointTestCase, e.getMessage())); + } + } + } + + private void validateBdd(ServiceShape serviceShape, EndpointTestsTrait testsTrait, List events) { + EndpointBddTrait trait = serviceShape.expectTrait(EndpointBddTrait.class); + for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { + try { + TestEvaluator.evaluate(trait, endpointTestCase); + } catch (RuntimeException e) { + events.add(error(serviceShape, endpointTestCase, e.getMessage())); + } + } + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java index c78a3a5b04a..0ebfc6a863f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java @@ -6,22 +6,19 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.Stream; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.Endpoint; -import software.amazon.smithy.rulesengine.language.TraversingVisitor; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.LiteralVisitor; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; -import software.amazon.smithy.utils.OptionalUtils; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -32,75 +29,60 @@ public final class RuleSetUriValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(new UriSchemeVisitor(serviceShape) - .visitRuleset(serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet()) - .collect(Collectors.toList())); + for (ServiceShape serviceShape : model.getServiceShapes()) { + visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(EndpointBddTrait.class).orElse(null)); } return events; } - private final class UriSchemeVisitor extends TraversingVisitor { - private final ServiceShape serviceShape; - private boolean checkingEndpoint = false; - - UriSchemeVisitor(ServiceShape serviceShape) { - this.serviceShape = serviceShape; - } - - @Override - public Stream visitEndpoint(Endpoint endpoint) { - checkingEndpoint = true; - Stream errors = endpoint.getUrl().accept(this); - checkingEndpoint = false; - return errors; + private void visitRuleset(List events, ServiceShape serviceShape, EndpointRuleSetTrait trait) { + if (trait != null) { + for (Rule rule : trait.getEndpointRuleSet().getRules()) { + traverse(events, serviceShape, rule); + } } + } - @Override - public Stream visitLiteral(Literal literal) { - return literal.accept(new LiteralVisitor>() { - @Override - public Stream visitBoolean(boolean b) { - return Stream.empty(); - } - - @Override - public Stream visitString(Template value) { - return OptionalUtils.stream(validateTemplate(value)); - } - - @Override - public Stream visitRecord(Map members) { - return Stream.empty(); + private void visitBdd(List events, ServiceShape serviceShape, EndpointBddTrait trait) { + if (trait != null) { + for (Rule result : trait.getResults()) { + if (result instanceof EndpointRule) { + visitEndpoint(events, serviceShape, (EndpointRule) result); } + } + } + } - @Override - public Stream visitTuple(List members) { - return Stream.empty(); - } + private void traverse(List events, ServiceShape service, Rule rule) { + if (rule instanceof EndpointRule) { + visitEndpoint(events, service, (EndpointRule) rule); + } else if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule child : treeRule.getRules()) { + traverse(events, service, child); + } + } + } - @Override - public Stream visitInteger(int value) { - return Stream.empty(); - } - }); + private void visitEndpoint(List events, ServiceShape serviceShape, EndpointRule endpointRule) { + Endpoint endpoint = endpointRule.getEndpoint(); + Expression url = endpoint.getUrl(); + if (url instanceof StringLiteral) { + StringLiteral s = (StringLiteral) url; + visitTemplate(events, serviceShape, s.value()); } + } - private Optional validateTemplate(Template template) { - if (checkingEndpoint) { - Template.Part part = template.getParts().get(0); - if (part instanceof Template.Literal) { - String scheme = ((Template.Literal) part).getValue(); - if (!(scheme.startsWith("http://") || scheme.startsWith("https://"))) { - return Optional.of(error(serviceShape, - template, - "URI should start with `http://` or `https://` but the URI started with " - + scheme)); - } - } - // Allow dynamic URIs for now — we should lint that at looks like a scheme at some point + private void visitTemplate(List events, ServiceShape serviceShape, Template template) { + Template.Part part = template.getParts().get(0); + if (part instanceof Template.Literal) { + String scheme = ((Template.Literal) part).getValue(); + if (!(scheme.startsWith("http://") || scheme.startsWith("https://"))) { + events.add(error(serviceShape, + template, + "URI should start with `http://` or `https://` but the URI started with " + scheme)); } - return Optional.empty(); } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RulesEngineVersionValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RulesEngineVersionValidator.java new file mode 100644 index 00000000000..65d2bdb0c70 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RulesEngineVersionValidator.java @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.validators; + +import java.util.ArrayList; +import java.util.List; +import software.amazon.smithy.model.FromSourceLocation; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.SourceLocation; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.validation.AbstractValidator; +import software.amazon.smithy.model.validation.ValidationEvent; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.syntax.SyntaxElement; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; + +/** + * Validates that the rules engine version of a trait only uses compatible features. + */ +public final class RulesEngineVersionValidator extends AbstractValidator { + + @Override + public List validate(Model model) { + List events = new ArrayList<>(); + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateBdd(events, s, s.expectTrait(EndpointBddTrait.class)); + } + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { + validateTree(events, s, s.expectTrait(EndpointRuleSetTrait.class)); + } + + return events; + } + + private void validateBdd(List events, ServiceShape service, EndpointBddTrait trait) { + RulesVersion version = trait.getVersion(); + + for (Condition condition : trait.getConditions()) { + validateSyntaxElement(events, service, condition, version); + } + + for (Rule result : trait.getResults()) { + validateRule(events, service, result, version); + } + } + + private void validateTree(List events, ServiceShape service, EndpointRuleSetTrait trait) { + EndpointRuleSet rules = trait.getEndpointRuleSet(); + RulesVersion version = rules.getRulesVersion(); + for (Rule rule : rules.getRules()) { + validateRule(events, service, rule, version); + } + } + + private void validateRule(List events, ServiceShape service, Rule rule, RulesVersion version) { + for (Condition condition : rule.getConditions()) { + validateSyntaxElement(events, service, condition, version); + validateSyntaxElement(events, service, condition.getFunction(), version); + for (Expression arg : condition.getFunction().getArguments()) { + validateSyntaxElement(events, service, arg, version); + } + } + + if (rule instanceof TreeRule) { + for (Rule nestedRule : ((TreeRule) rule).getRules()) { + validateRule(events, service, nestedRule, version); + } + } else if (rule instanceof EndpointRule) { + EndpointRule endpointRule = (EndpointRule) rule; + validateSyntaxElement(events, service, endpointRule.getEndpoint().getUrl(), version); + for (List headerValues : endpointRule.getEndpoint().getHeaders().values()) { + for (Expression expr : headerValues) { + validateSyntaxElement(events, service, expr, version); + } + } + } else if (rule instanceof ErrorRule) { + validateSyntaxElement(events, service, ((ErrorRule) rule).getError(), version); + } + } + + private void validateSyntaxElement( + List events, + ServiceShape service, + SyntaxElement element, + RulesVersion declaredVersion + ) { + RulesVersion requiredVersion = element.availableSince(); + + if (!declaredVersion.isAtLeast(requiredVersion)) { + SourceLocation s = element instanceof FromSourceLocation + ? ((FromSourceLocation) element).getSourceLocation() + : element.toExpression().getSourceLocation(); + String msg = String.format( + "%s requires rules engine version >= %s, but ruleset declares version %s", + element.getClass().getSimpleName(), + requiredVersion, + declaredVersion); + events.add(error(service, s, msg)); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/StaticContextParamsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/StaticContextParamsTraitValidator.java similarity index 89% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/StaticContextParamsTraitValidator.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/StaticContextParamsTraitValidator.java index de3a344c261..8b01f796cc4 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/StaticContextParamsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/StaticContextParamsTraitValidator.java @@ -2,7 +2,7 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rulesengine.traits; +package software.amazon.smithy.rulesengine.validators; import java.util.ArrayList; import java.util.Collections; @@ -13,6 +13,9 @@ import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; +import software.amazon.smithy.rulesengine.traits.ContextIndex; +import software.amazon.smithy.rulesengine.traits.StaticContextParamDefinition; +import software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait; import software.amazon.smithy.utils.SmithyUnstableApi; /** diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService index 353f3b95e8e..0edd3cd6758 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService @@ -4,3 +4,4 @@ software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointTestsTrait$Provider +software.amazon.smithy.rulesengine.traits.EndpointBddTrait$Provider diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator index 2f24c94156b..18ef8dab55c 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator @@ -1,9 +1,11 @@ -software.amazon.smithy.rulesengine.traits.EndpointTestsTraitValidator -software.amazon.smithy.rulesengine.traits.StaticContextParamsTraitValidator -software.amazon.smithy.rulesengine.traits.OperationContextParamsTraitValidator +software.amazon.smithy.rulesengine.validators.EndpointTestsTraitValidator +software.amazon.smithy.rulesengine.validators.StaticContextParamsTraitValidator +software.amazon.smithy.rulesengine.validators.OperationContextParamsTraitValidator software.amazon.smithy.rulesengine.validators.RuleSetAuthSchemesValidator software.amazon.smithy.rulesengine.validators.RuleSetBuiltInValidator software.amazon.smithy.rulesengine.validators.RuleSetUriValidator software.amazon.smithy.rulesengine.validators.RuleSetParamMissingDocsValidator software.amazon.smithy.rulesengine.validators.RuleSetParameterValidator software.amazon.smithy.rulesengine.validators.RuleSetTestCaseValidator +software.amazon.smithy.rulesengine.validators.BddTraitValidator +software.amazon.smithy.rulesengine.validators.RulesEngineVersionValidator diff --git a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy index c475281bc3f..65a480c49e7 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy +++ b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy @@ -7,9 +7,196 @@ namespace smithy.rules @trait(selector: "service") document endpointRuleSet +/// Defines an endpoint rule-set using a binary decision diagram (BDD). +@unstable +@trait(selector: "service") +structure endpointBdd { + /// The rules engine version. Must be set to 1.1 or higher. + @required + version: String + + /// A map of zero or more endpoint parameter names to their parameter configuration. + @required + parameters: Parameters + + /// An ordered list of unique conditions used throughout the BDD. + @required + conditions: Conditions + + /// An ordered list of results referenced by BDD nodes. The first result is always the terminal node. + @required + results: Results + + /// The root node of where to start evaluating the BDD. + @required + @range(min: -1) + root: Integer + + /// The number of nodes contained in the BDD. + @required + @range(min: 0) + nodeCount: Integer + + /// Base64-encoded array of BDD nodes representing the decision graph structure. + /// + /// All integers are encoded in big-endian. + /// + /// The first node (index 0) is always the terminal node `[-1, 1, -1]` and is included in the nodeCount. + /// User-defined nodes start at index 1. + /// + /// Each node is written one after the other and consists of three integers written sequentially: + /// 1. variable index + /// 2. high reference (when condition is true) + /// 3. low reference (when condition is false) + /// + /// Node Structure [variable, high, low]: + /// - variable: The index of the condition being tested (0 to conditionCount-1) + /// - high: Reference to follow when the condition evaluates to true + /// - low: Reference to follow when the condition evaluates to false + /// + /// Reference Encoding: + /// - 0: Invalid/unused reference (never appears in valid BDDs) + /// - 1: TRUE terminal (treated as "no match" in endpoint resolution) + /// - -1: FALSE terminal (treated as "no match" in endpoint resolution) + /// - 2, 3, 4, ...: Node references pointing to nodes[ref-1] + /// - -2, -3, -4, ...: Complement node references (logical NOT of nodes[abs(ref)-1]) + /// - 100000000+: Result terminals (100000000 + resultIndex) + /// + /// Complement edges: + /// A negative reference represents the logical NOT of the referenced node's entire subgraph. So `-5` means the + /// complement of node 5 (located in the array at index 4, since `index = |ref| - 1`). In this case, evaluate the + /// condition referenced by node 4, and if it is TRUE, use the low reference, and if it's FALSE, use the high + /// reference. This optimization significantly reduces BDD size by allowing a single subgraph to represent both a + /// boolean function and its complement; instead of creating separate nodes for `condition AND other` and + /// `NOT(condition AND other)`, we can reuse the same nodes with complement edges. Complement edges cannot be + /// used on result terminals. + @required + nodes: String +} + +@private +map Parameters { + key: String + value: Parameter +} + +/// A rules input parameter. +@private +structure Parameter { + /// The parameter type. + @required + type: ParameterType + + /// True if the parameter is deprecated. + deprecated: Boolean + + /// Documentation about the parameter. + documentation: String + + /// Specifies the default value for the parameter if not set. + /// Parameters with defaults MUST also be marked as required. The type of the provided default MUST match type. + default: Document + + /// Specifies a named built-in value that is sourced and provided to the endpoint provider by a caller. + builtIn: String + + /// Specifies that the parameter is required to be provided to the endpoint provider. + required: Boolean +} + +/// The kind of parameter. +enum ParameterType { + STRING = "string" + BOOLEAN = "boolean" + STRING_ARRAY = "stringArray" +} + +@private +list Conditions { + member: Condition +} + +@private +structure Condition { + /// The name of the function to be executed. + @required + fn: String + + /// The arguments for the function. + /// An array of one or more of the following types: string, bool, array, Reference object, or Function object + @required + argv: DocumentList + + /// The optional destination variable to assign the functions result to. + assign: String +} + +@private +list DocumentList { + member: Document +} + +@private +list Results { + member: Result +} + +@private +structure Result { + /// Result type. + @required + type: ResultType + + /// An optional description of the result. + documentation: String + + /// Provided if type is "error". + error: Document + + /// Provided if type is "endpoint". + endpoint: EndpointObject + + /// Conditions for the result (only used with decision tree rules). + conditions: Conditions +} + +@private +enum ResultType { + ENDPOINT = "endpoint" + ERROR = "error" +} + +@private +structure EndpointObject { + /// The endpoint url. This MUST specify a scheme and hostname and MAY contain port and base path components. + /// A string value MAY be a Template string. Any value for this property MUST resolve to a string. + @required + url: Document + + /// A map containing zero or more key value property pairs. Endpoint properties MAY be arbitrarily deep and + /// contain other maps and arrays. + properties: EndpointProperties + + /// A map of transport header names to their respective values. A string value in an array MAY be a + /// template string. + headers: EndpointObjectHeaders +} + +@private +map EndpointProperties { + key: String + value: Document +} + +@private +map EndpointObjectHeaders { + key: String + value: DocumentList +} + /// Defines endpoint test-cases for validating a client's endpoint rule-set. @unstable -@trait(selector: "service[trait|smithy.rules#endpointRuleSet]") +@trait(selector: "service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#endpointBdd])") structure endpointTests { /// The endpoint tests schema version. @required diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/analysis/BddCoverageCheckerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/analysis/BddCoverageCheckerTest.java new file mode 100644 index 00000000000..1e1a7868d41 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/analysis/BddCoverageCheckerTest.java @@ -0,0 +1,278 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.analysis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestExpectation; +import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint; +import software.amazon.smithy.utils.ListUtils; +import software.amazon.smithy.utils.MapUtils; + +public class BddCoverageCheckerTest { + + private EndpointBddTrait bddTrait; + private BddCoverageChecker checker; + + @BeforeEach + void setUp() { + // Create a simple ruleset with multiple conditions and results + Parameters parameters = Parameters.builder() + .addParameter(Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .required(false) + .build()) + .addParameter(Parameter.builder() + .name("UseFips") + .type(ParameterType.BOOLEAN) + .required(true) + .defaultValue(Value.booleanValue(false)) + .build()) + .addParameter(Parameter.builder() + .name("UseDualStack") + .type(ParameterType.BOOLEAN) + .required(true) + .defaultValue(Value.booleanValue(false)) + .build()) + .build(); + + List rules = new ArrayList<>(); + + // Rule 1: If Region is set and UseFips is true, error + rules.add(ErrorRule.builder() + .conditions(ListUtils.of( + Condition.builder() + .fn(IsSet.ofExpressions(Expression.getReference("Region"))) + .build(), + Condition.builder() + .fn(BooleanEquals.ofExpressions(Expression.getReference("UseFips"), true)) + .build())) + .error("FIPS not supported in this region")); + + // Rule 2: If Region is set and UseDualStack is true, specific endpoint + rules.add(EndpointRule.builder() + .conditions(ListUtils.of( + Condition.builder() + .fn(IsSet.ofExpressions(Expression.getReference("Region"))) + .build(), + Condition.builder() + .fn(BooleanEquals.ofExpressions(Expression.getReference("UseDualStack"), true)) + .build())) + .endpoint(TestHelpers.endpoint("https://dualstack.example.com"))); + + // Rule 3: Default endpoint + rules.add(EndpointRule.builder().endpoint(TestHelpers.endpoint("https://empty.com"))); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(parameters) + .rules(rules) + .version("1.1") + .build(); + + // Convert to BDD + Cfg cfg = Cfg.from(ruleSet); + bddTrait = EndpointBddTrait.from(cfg); + checker = new BddCoverageChecker(bddTrait); + } + + @Test + void testInitialCoverage() { + assertEquals(0.0, checker.getConditionCoverage()); + assertEquals(0.0, checker.getResultCoverage()); + } + + @Test + void testSingleTestCaseCoverage() { + // Test with Region set and UseFips true + EndpointTestCase testCase = EndpointTestCase.builder() + .params(ObjectNode.builder() + .withMember("Region", "us-east-1") + .withMember("UseFips", true) + .withMember("UseDualStack", false) + .build()) + .expect(EndpointTestExpectation.builder() + .error("FIPS not supported in this region") + .build()) + .build(); + + checker.evaluateTestCase(testCase); + + // Should have covered some conditions + assertTrue(checker.getConditionCoverage() > 0.0); + assertTrue(checker.getResultCoverage() > 0.0); + + // Should have fewer unevaluated conditions + Set unevaluatedConditions = checker.getUnevaluatedConditions(); + assertTrue(unevaluatedConditions.size() < bddTrait.getConditions().size()); + } + + @Test + void testMultipleTestCasesCoverage() { + // Test case 1: FIPS error path + checker.evaluateTestCase(EndpointTestCase.builder() + .params(ObjectNode.builder() + .withMember("Region", "us-east-1") + .withMember("UseFips", true) + .withMember("UseDualStack", false) + .build()) + .expect(EndpointTestExpectation.builder() + .error("FIPS not supported in this region") + .build()) + .build()); + + double coverageAfterFirst = checker.getConditionCoverage(); + + // Test case 2: Dual stack path + checker.evaluateTestCase(EndpointTestCase.builder() + .params(ObjectNode.builder() + .withMember("Region", "us-east-1") + .withMember("UseFips", false) + .withMember("UseDualStack", true) + .build()) + .expect(EndpointTestExpectation.builder() + .endpoint(ExpectedEndpoint.builder() + .url("https://dualstack.example.com") + .build()) + .build()) + .build()); + + double coverageAfterSecond = checker.getConditionCoverage(); + + // Coverage should increase or stay the same + assertTrue(coverageAfterSecond >= coverageAfterFirst); + + // Test case 3: Default endpoint + checker.evaluateTestCase(EndpointTestCase.builder() + .params(ObjectNode.builder() + .withMember("UseFips", false) + .withMember("UseDualStack", false) + .build()) + .expect(EndpointTestExpectation.builder() + .endpoint(ExpectedEndpoint.builder() + .url("https://example.com") + .build()) + .build()) + .build()); + + double coverageAfterThird = checker.getConditionCoverage(); + assertTrue(coverageAfterThird >= coverageAfterSecond); + } + + @Test + void testEvaluateInput() { + // Test direct input evaluation + checker.evaluateInput(MapUtils.of( + Identifier.of("Region"), + Value.stringValue("us-west-2"), + Identifier.of("UseFips"), + Value.booleanValue(false), + Identifier.of("UseDualStack"), + Value.booleanValue(true))); + + assertTrue(checker.getConditionCoverage() > 0.0); + assertTrue(checker.getResultCoverage() > 0.0); + } + + @Test + void testFullCoverage() { + List testCases = ListUtils.of( + // FIPS error + EndpointTestCase.builder() + .params(ObjectNode.builder() + .withMember("Region", "us-east-1") + .withMember("UseFips", true) + .build()) + .expect(EndpointTestExpectation.builder() + .error("FIPS not supported in this region") + .build()) + .build(), + // Dual stack + EndpointTestCase.builder() + .params(ObjectNode.builder() + .withMember("Region", "us-east-1") + .withMember("UseDualStack", true) + .build()) + .expect(EndpointTestExpectation.builder() + .endpoint(ExpectedEndpoint.builder() + .url("https://dualstack.example.com") + .build()) + .build()) + .build(), + // Default with region + EndpointTestCase.builder() + .params(ObjectNode.builder() + .withMember("Region", "us-east-1") + .withMember("UseFips", false) + .withMember("UseDualStack", false) + .build()) + .expect(EndpointTestExpectation.builder() + .endpoint(ExpectedEndpoint.builder() + .url("https://example.com") + .build()) + .build()) + .build(), + // Default without region + EndpointTestCase.builder() + .params(ObjectNode.objectNode()) + .expect(EndpointTestExpectation.builder() + .endpoint(ExpectedEndpoint.builder() + .url("https://example.com") + .build()) + .build()) + .build()); + + for (EndpointTestCase testCase : testCases) { + checker.evaluateTestCase(testCase); + } + + assertEquals(100.0, checker.getConditionCoverage()); + assertEquals(100.0, checker.getResultCoverage()); + } + + @Test + void testEmptyBdd() { + Parameters emptyParams = Parameters.builder().build(); + EndpointRuleSet emptyRuleSet = EndpointRuleSet.builder() + .parameters(emptyParams) + .rules(ListUtils.of(EndpointRule.builder().endpoint(TestHelpers.endpoint("https://empty.com")))) + .version("1.1") + .build(); + + Cfg emptyCfg = Cfg.from(emptyRuleSet); + EndpointBddTrait emptyBdd = EndpointBddTrait.from(emptyCfg); + BddCoverageChecker emptyChecker = new BddCoverageChecker(emptyBdd); + + assertEquals(100.0, emptyChecker.getConditionCoverage(), 0.01); + assertEquals(0.0, emptyChecker.getResultCoverage(), 0.01); + + emptyChecker.evaluateInput(MapUtils.of()); + assertTrue(emptyChecker.getResultCoverage() > 0.0); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java new file mode 100644 index 00000000000..bf6ac4bb9da --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java @@ -0,0 +1,218 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + +public class CoalesceTest { + + @Test + void testCoalesceWithTwoSameTypes() { + Expression left = Literal.of("default"); + Expression right = Literal.of("fallback"); + Coalesce coalesce = Coalesce.ofExpressions(left, right); + + Scope scope = new Scope<>(); + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceWithThreeSameTypes() { + Expression first = Literal.of("first"); + Expression second = Literal.of("second"); + Expression third = Literal.of("third"); + Coalesce coalesce = Coalesce.ofExpressions(first, second, third); + + Scope scope = new Scope<>(); + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceVariadicWithList() { + Expression first = Literal.of(1); + Expression second = Literal.of(2); + Expression third = Literal.of(3); + Expression fourth = Literal.of(4); + + Coalesce coalesce = Coalesce.ofExpressions(Arrays.asList(first, second, third, fourth)); + + Scope scope = new Scope<>(); + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.integerType(), resultType); + } + + @Test + void testCoalesceWithOptionalLeft() { + Expression optionalVar = Expression.getReference(Identifier.of("maybeValue")); + Expression fallback = Literal.of("default"); + Coalesce coalesce = Coalesce.ofExpressions(optionalVar, fallback); + + Scope scope = new Scope<>(); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + Type resultType = coalesce.typeCheck(scope); + + // Should unwrap optional and return non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceWithAllOptional() { + Expression var1 = Expression.getReference(Identifier.of("maybe1")); + Expression var2 = Expression.getReference(Identifier.of("maybe2")); + Coalesce coalesce = Coalesce.ofExpressions(var1, var2); + + Scope scope = new Scope<>(); + scope.insert("maybe1", Type.optionalType(Type.stringType())); + scope.insert("maybe2", Type.optionalType(Type.stringType())); + + Type resultType = coalesce.typeCheck(scope); + + // Both optional means result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testCoalesceThreeWithAllOptional() { + Expression var1 = Expression.getReference(Identifier.of("maybe1")); + Expression var2 = Expression.getReference(Identifier.of("maybe2")); + Expression var3 = Expression.getReference(Identifier.of("maybe3")); + Coalesce coalesce = Coalesce.ofExpressions(var1, var2, var3); + + Scope scope = new Scope<>(); + scope.insert("maybe1", Type.optionalType(Type.integerType())); + scope.insert("maybe2", Type.optionalType(Type.integerType())); + scope.insert("maybe3", Type.optionalType(Type.integerType())); + + Type resultType = coalesce.typeCheck(scope); + + // All optional means result is optional + assertEquals(Type.optionalType(Type.integerType()), resultType); + } + + @Test + void testCoalesceMixedOptionalAndNonOptional() { + Expression optional1 = Expression.getReference(Identifier.of("optional1")); + Expression required = Expression.getReference(Identifier.of("required")); + Expression optional2 = Expression.getReference(Identifier.of("optional2")); + + Coalesce coalesce = Coalesce.ofExpressions(optional1, required, optional2); + + Scope scope = new Scope<>(); + scope.insert("optional1", Type.optionalType(Type.stringType())); + scope.insert("required", Type.stringType()); + scope.insert("optional2", Type.optionalType(Type.stringType())); + + Type resultType = coalesce.typeCheck(scope); + + // Any non-optional in the chain makes result non-optional + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceWithIncompatibleTypes() { + Expression stringExpr = Literal.of("text"); + Expression intExpr = Literal.of(42); + Coalesce coalesce = Coalesce.ofExpressions(stringExpr, intExpr); + + Scope scope = new Scope<>(); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); + assertTrue(ex.getMessage().contains("argument 2")); + } + + @Test + void testCoalesceWithIncompatibleTypesInMiddle() { + Expression int1 = Literal.of(1); + Expression int2 = Literal.of(2); + Expression string = Literal.of("oops"); + Expression int3 = Literal.of(3); + + Coalesce coalesce = Coalesce.ofExpressions(int1, int2, string, int3); + + Scope scope = new Scope<>(); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); + assertTrue(ex.getMessage().contains("argument 3")); + } + + @Test + void testCoalesceWithLessThanTwoArguments() { + Expression single = Literal.of("only"); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); + assertTrue(ex.getMessage().contains("at least 2 arguments")); + } + + @Test + void testCoalesceArrayTypes() { + Expression arr1 = Expression.getReference(Identifier.of("array1")); + Expression arr2 = Expression.getReference(Identifier.of("array2")); + Expression arr3 = Expression.getReference(Identifier.of("array3")); + Coalesce coalesce = Coalesce.ofExpressions(arr1, arr2, arr3); + + Scope scope = new Scope<>(); + scope.insert("array1", Type.arrayType(Type.stringType())); + scope.insert("array2", Type.arrayType(Type.stringType())); + scope.insert("array3", Type.arrayType(Type.stringType())); + + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.arrayType(Type.stringType()), resultType); + } + + @Test + void testCoalesceOptionalArrayTypes() { + Expression arr1 = Expression.getReference(Identifier.of("array1")); + Expression arr2 = Expression.getReference(Identifier.of("array2")); + Coalesce coalesce = Coalesce.ofExpressions(arr1, arr2); + + Scope scope = new Scope<>(); + scope.insert("array1", Type.optionalType(Type.arrayType(Type.integerType()))); + scope.insert("array2", Type.arrayType(Type.integerType())); + + Type resultType = coalesce.typeCheck(scope); + + // One non-optional makes result non-optional + assertEquals(Type.arrayType(Type.integerType()), resultType); + } + + @Test + void testCoalesceWithBooleanTypes() { + Expression bool1 = Expression.getReference(Identifier.of("bool1")); + Expression bool2 = Expression.getReference(Identifier.of("bool2")); + Expression bool3 = Expression.getReference(Identifier.of("bool3")); + + Coalesce coalesce = Coalesce.ofExpressions(bool1, bool2, bool3); + + Scope scope = new Scope<>(); + scope.insert("bool1", Type.optionalType(Type.booleanType())); + scope.insert("bool2", Type.optionalType(Type.booleanType())); + scope.insert("bool3", Type.booleanType()); + + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.booleanType(), resultType); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/value/ToObjectTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/value/ToObjectTest.java new file mode 100644 index 00000000000..6bec1847317 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/value/ToObjectTest.java @@ -0,0 +1,87 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.value; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.utils.ListUtils; +import software.amazon.smithy.utils.MapUtils; + +class ToObjectTest { + @Test + void testStringValueToObject() { + Value value = Value.stringValue("hello"); + + assertEquals("hello", value.toObject()); + } + + @Test + void testIntegerValueToObject() { + Value value = Value.integerValue(42); + + assertEquals(42, value.toObject()); + } + + @Test + void testBooleanValueToObject() { + assertEquals(Boolean.TRUE, Value.booleanValue(true).toObject()); + assertEquals(Boolean.FALSE, Value.booleanValue(false).toObject()); + } + + @Test + void testEmptyValueToObject() { + Value value = Value.emptyValue(); + + assertNull(value.toObject()); + } + + @Test + void testArrayValueToObject() { + Value arrayValue = Value.arrayValue(ListUtils.of( + Value.stringValue("a"), + Value.integerValue(1), + Value.booleanValue(true))); + + Object result = arrayValue.toObject(); + assertInstanceOf(List.class, result); + + List list = (List) result; + assertEquals(3, list.size()); + assertEquals("a", list.get(0)); + assertEquals(1, list.get(1)); + assertEquals(true, list.get(2)); + } + + @Test + void testRecordValueToObject() { + Map map = new LinkedHashMap<>(); + map.put(Identifier.of("name"), Value.stringValue("test")); + map.put(Identifier.of("count"), Value.integerValue(5)); + map.put(Identifier.of("enabled"), Value.booleanValue(true)); + + Value recordValue = Value.recordValue(map); + Object result = recordValue.toObject(); + + assertInstanceOf(Map.class, result); + Map resultMap = (Map) result; + assertEquals("test", resultMap.get("name")); + assertEquals(5, resultMap.get("count")); + assertEquals(true, resultMap.get("enabled")); + } + + @Test + void testEmptyCollections() { + assertEquals(ListUtils.of(), Value.arrayValue(ListUtils.of()).toObject()); + assertEquals(MapUtils.of(), Value.recordValue(MapUtils.of()).toObject()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluatorTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluatorTest.java new file mode 100644 index 00000000000..34fed409058 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluatorTest.java @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.RuleEvaluator; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +class RuleBasedConditionEvaluatorTest { + + @Test + void testEvaluatesConditions() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); + Condition[] conditions = {cond1, cond2}; + + // Create a mock evaluator that returns true for first condition, false for second + RuleEvaluator mockEvaluator = new RuleEvaluator() { + @Override + public Value evaluateCondition(Condition condition) { + if (condition == cond1) { + return Value.booleanValue(true); + } else { + return Value.booleanValue(false); + } + } + }; + + RuleBasedConditionEvaluator evaluator = new RuleBasedConditionEvaluator(mockEvaluator, conditions); + + assertTrue(evaluator.test(0)); + assertFalse(evaluator.test(1)); + } + + @Test + void testHandlesEmptyValue() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); + Condition[] conditions = {cond}; + + RuleEvaluator mockEvaluator = new RuleEvaluator() { + @Override + public Value evaluateCondition(Condition condition) { + return Value.emptyValue(); + } + }; + + RuleBasedConditionEvaluator evaluator = new RuleBasedConditionEvaluator(mockEvaluator, conditions); + assertFalse(evaluator.test(0)); + } + + @Test + void testHandlesNonBooleanTruthyValue() { + Condition cond = Condition.builder().fn(TestHelpers.parseUrl("https://example.com")).build(); + Condition[] conditions = {cond}; + + RuleEvaluator mockEvaluator = new RuleEvaluator() { + @Override + public Value evaluateCondition(Condition condition) { + return Value.stringValue("some-string"); + } + }; + + RuleBasedConditionEvaluator evaluator = new RuleBasedConditionEvaluator(mockEvaluator, conditions); + assertTrue(evaluator.test(0)); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java new file mode 100644 index 00000000000..0b5834ea7d4 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java @@ -0,0 +1,114 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsValidHostLabel; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.UriEncode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; + +public final class TestHelpers { + + private TestHelpers() {} + + public static LibraryFunction isSet(String paramName) { + return IsSet.ofExpressions(Expression.getReference(Identifier.of(paramName))); + } + + public static LibraryFunction stringEquals(String paramName, String value) { + return StringEquals.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + StringLiteral.of(value)); + } + + public static LibraryFunction stringEquals(Expression expr1, Expression expr2) { + return StringEquals.ofExpressions(expr1, expr2); + } + + public static LibraryFunction booleanEquals(String paramName, boolean value) { + return BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.booleanLiteral(value)); + } + + public static LibraryFunction booleanEquals(Expression expr, boolean value) { + return BooleanEquals.ofExpressions(expr, Literal.booleanLiteral(value)); + } + + public static LibraryFunction parseUrl(String paramName) { + return ParseUrl.ofExpressions(Expression.getReference(Identifier.of(paramName))); + } + + public static LibraryFunction getAttr(Expression expr, String path) { + return GetAttr.ofExpressions(expr, Literal.of(path)); + } + + public static LibraryFunction getAttr(String paramName, String path) { + return GetAttr.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.of(path)); + } + + public static LibraryFunction substring(String paramName, int start, int stop, boolean reverse) { + return Substring.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.of(start), + Literal.of(stop), + Literal.of(reverse)); + } + + public static LibraryFunction substring(Expression expr, int start, int stop, boolean reverse) { + return Substring.ofExpressions( + expr, + Literal.of(start), + Literal.of(stop), + Literal.of(reverse)); + } + + public static LibraryFunction not(Expression expr) { + return Not.ofExpressions(expr); + } + + public static LibraryFunction not(LibraryFunction fn) { + return Not.ofExpressions(fn); + } + + public static LibraryFunction isValidHostLabel(String paramName, boolean allowDots) { + return IsValidHostLabel.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.of(allowDots)); + } + + public static LibraryFunction isValidHostLabel(Expression expr, boolean allowDots) { + return IsValidHostLabel.ofExpressions(expr, Literal.of(allowDots)); + } + + public static LibraryFunction uriEncode(String paramName) { + return UriEncode.ofExpressions(Expression.getReference(Identifier.of(paramName))); + } + + public static LibraryFunction uriEncode(Expression expr) { + return UriEncode.ofExpressions(expr); + } + + public static Endpoint endpoint(String url) { + return Endpoint.builder().url(Expression.of(url)).build(); + } + + public static Endpoint endpoint(Expression url) { + return Endpoint.builder().url(url).build(); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java new file mode 100644 index 00000000000..d0428d22194 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java @@ -0,0 +1,475 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class BddBuilderTest { + + private BddBuilder builder; + + @BeforeEach + void setUp() { + builder = new BddBuilder(); + } + + @Test + void testTerminals() { + assertEquals(1, builder.makeTrue()); + assertEquals(-1, builder.makeFalse()); + + // Terminals are constants + assertEquals(1, builder.makeTrue()); + assertEquals(-1, builder.makeFalse()); + } + + @Test + void testNodeReduction() { + builder.setConditionCount(2); + + // Node with identical branches should be reduced + int reduced = builder.makeNode(0, builder.makeTrue(), builder.makeTrue()); + assertEquals(1, reduced); // Should return TRUE directly + + // Verify no new node was created (only terminal at index 0) + assertEquals(1, builder.getNodeCount()); // Only terminal node exists + } + + @Test + void testComplementCanonicalization() { + builder.setConditionCount(2); + + // Create node with complement on low branch + int node1 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + + // Create equivalent node with flipped branches and complement + // This should canonicalize to the same node but with complement + int node2 = builder.makeNode(0, builder.makeFalse(), builder.makeTrue()); + + assertEquals(-node1, node2); // Should be complement of first node + + // Only one actual node should be created (plus terminal) + assertEquals(2, builder.getNodeCount()); // Terminal + one node + } + + @Test + void testNodeDeduplication() { + builder.setConditionCount(2); + + // Create same node twice + int node1 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int node2 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + + assertEquals(node1, node2); // Should return same reference + assertEquals(2, builder.getNodeCount()); // Terminal + one node (no duplicate) + } + + @Test + void testResultNodes() { + builder.setConditionCount(2); + + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // Result nodes should have distinct references + assertTrue(result0 != result1); + assertTrue(builder.isResult(result0)); + assertTrue(builder.isResult(result1)); + assertFalse(builder.isResult(builder.makeTrue())); + + // Result refs should be encoded with RESULT_OFFSET + assertEquals(Bdd.RESULT_OFFSET, result0); + assertEquals(Bdd.RESULT_OFFSET + 1, result1); + } + + @Test + void testNegation() { + builder.setConditionCount(2); + + int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int negated = builder.negate(node); + + assertEquals(-node, negated); + assertEquals(node, builder.negate(negated)); // Double negation + + // Cannot negate terminals + assertEquals(-1, builder.negate(builder.makeTrue())); + assertEquals(1, builder.negate(builder.makeFalse())); + } + + @Test + void testNegateResult() { + builder.setConditionCount(2); + int result = builder.makeResult(0); + + assertThrows(IllegalArgumentException.class, () -> builder.negate(result)); + } + + @Test + void testAndOperation() { + builder.setConditionCount(2); + + // TRUE AND TRUE = TRUE + assertEquals(1, builder.and(builder.makeTrue(), builder.makeTrue())); + + // TRUE AND FALSE = FALSE + assertEquals(-1, builder.and(builder.makeTrue(), builder.makeFalse())); + + // FALSE AND x = FALSE + assertEquals(-1, builder.and(builder.makeFalse(), builder.makeTrue())); + assertEquals(-1, builder.and(builder.makeFalse(), builder.makeFalse())); + + // x AND x = x + int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertEquals(node, builder.and(node, node)); + } + + @Test + void testOrOperation() { + builder.setConditionCount(2); + + // FALSE OR FALSE = FALSE + assertEquals(-1, builder.or(builder.makeFalse(), builder.makeFalse())); + + // TRUE OR x = TRUE + assertEquals(1, builder.or(builder.makeTrue(), builder.makeFalse())); + assertEquals(1, builder.or(builder.makeTrue(), builder.makeTrue())); + + // FALSE OR TRUE = TRUE + assertEquals(1, builder.or(builder.makeFalse(), builder.makeTrue())); + + // x OR x = x + int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertEquals(node, builder.or(node, node)); + } + + @Test + void testIteBasicCases() { + builder.setConditionCount(2); + + // ITE(TRUE, g, h) = g + int g = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int h = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + assertEquals(g, builder.ite(builder.makeTrue(), g, h)); + + // ITE(FALSE, g, h) = h + assertEquals(h, builder.ite(builder.makeFalse(), g, h)); + + // ITE(f, g, g) = g + int f = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertEquals(g, builder.ite(f, g, g)); + } + + @Test + void testIteWithComplement() { + builder.setConditionCount(2); + + int f = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int g = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int h = builder.makeNode(1, builder.makeFalse(), builder.makeTrue()); + + // ITE with complemented condition should swap branches + int result1 = builder.ite(f, g, h); + int result2 = builder.ite(builder.negate(f), h, g); + assertEquals(result1, result2); + } + + @Test + void testResultInIte() { + builder.setConditionCount(1); + + int cond = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // ITE with result terminals + int ite = builder.ite(cond, result0, result1); + assertTrue(ite != 0); + + // Condition cannot be a result + assertThrows(IllegalArgumentException.class, + () -> builder.ite(result0, builder.makeTrue(), builder.makeFalse())); + } + + @Test + void testSetConditionCountRequired() { + assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); + } + + @Test + void testGetVariable() { + builder.setConditionCount(3); + + assertEquals(-1, builder.getVariable(builder.makeTrue())); + assertEquals(-1, builder.getVariable(builder.makeFalse())); + + int node = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + assertEquals(1, builder.getVariable(node)); + assertEquals(1, builder.getVariable(Math.abs(node))); // Use absolute value for complement + + int result = builder.makeResult(0); + assertEquals(-1, builder.getVariable(result)); // results have no variable + } + + @Test + void testReduceSimpleBdd() { + builder.setConditionCount(3); + + // makeNode already reduces nodes with identical branches + // So we need to create a different scenario for reduction + int a = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int b = builder.makeNode(1, a, builder.makeFalse()); + int root = builder.makeNode(0, b, a); + + int nodesBefore = builder.getNodeCount(); + builder.reduce(root); + + // Structure should be preserved if already optimal + assertEquals(nodesBefore, builder.getNodeCount()); + } + + @Test + void testReduceNoChange() { + builder.setConditionCount(2); + + // Create already-reduced BDD + int right = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int root = builder.makeNode(0, right, builder.makeFalse()); + + int nodesBefore = builder.getNodeCount(); + builder.reduce(root); + + assertEquals(nodesBefore, builder.getNodeCount()); + } + + @Test + void testReduceTerminals() { + // Reducing terminals should return them unchanged + assertEquals(1, builder.reduce(builder.makeTrue())); + assertEquals(-1, builder.reduce(builder.makeFalse())); + } + + @Test + void testReduceResults() { + builder.setConditionCount(1); + int result = builder.makeResult(0); + + // Reducing result nodes should return them unchanged + assertEquals(result, builder.reduce(result)); + } + + @Test + void testReduceWithComplement() { + builder.setConditionCount(3); + + // Create BDD with complement edges + int a = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int b = builder.makeNode(1, a, builder.negate(a)); + int root = builder.makeNode(0, b, builder.makeFalse()); + + int reduced = builder.reduce(root); + + // Now test reducing with complemented root + int complementRoot = builder.negate(root); + int reducedComplement = builder.reduce(complementRoot); + + // The result should be the complement of the reduced root + assertEquals(builder.negate(reduced), reducedComplement); + + // Verify the structure is preserved + assertTrue(builder.getNodeCount() > 1); // More than just terminal + } + + @Test + void testReduceClearsCache() { + builder.setConditionCount(2); + + // Create nodes and perform ITE to populate cache - use only boolean nodes + int a = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int b = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int ite1 = builder.ite(a, b, builder.makeFalse()); + + builder.reduce(ite1); + + // Cache should be cleared, so same ITE creates new result + // Recreate the nodes since reduce may have changed internal state + a = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + b = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int ite2 = builder.ite(a, b, builder.makeFalse()); + assertTrue(ite2 != 0); // Should get a valid reference + } + + @Test + void testReduceActuallyReduces() { + builder.setConditionCount(3); + + // First create some nodes + int bottom = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int middle = builder.makeNode(1, bottom, builder.makeFalse()); + int root = builder.makeNode(0, middle, bottom); + + int beforeSize = builder.getNodeCount(); + builder.reduce(root); + int afterSize = builder.getNodeCount(); + + // In this case, no reduction should occur since makeNode already optimized + assertEquals(beforeSize, afterSize); + } + + @Test + void testReduceWithPreExistingComplementStructure() { + builder.setConditionCount(3); + + // Create a structure where reduce will encounter complement on low during rebuild + // First create base nodes + int a = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + + // When we create this node, makeNode will canonicalize it + // The actual stored node will have the complement bit on the reference, not in the node + int b = builder.makeNode(1, builder.makeTrue(), builder.negate(a)); + + // This creates a scenario where during reduce's rebuild, + // makeNodeInNew might encounter the complement + int root = builder.makeNode(0, b, a); + + // Force a reduce operation + int reduced = builder.reduce(root); + + // The BDD should be functionally equivalent + // We can't make strong assertions about node count since reduce may optimize + assertTrue(reduced != 0); + + // Verify the BDD still evaluates correctly + // by checking that it's not a constant + assertNotEquals(reduced, builder.makeTrue()); + assertNotEquals(reduced, builder.makeFalse()); + } + + @Test + void testCofactorRecursive() { + builder.setConditionCount(3); + + // Create a multi-level BDD with only boolean nodes (no results) + int bottom = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int middle = builder.makeNode(1, bottom, builder.makeFalse()); + int root = builder.makeNode(0, middle, bottom); + + // Cofactor with respect to variable 1 (appears deeper in BDD) + int cofactorTrue = builder.cofactor(root, 1, true); + int cofactorFalse = builder.cofactor(root, 1, false); + + // The cofactors should be different + assertTrue(cofactorTrue != cofactorFalse); + + // Verify structure exists + assertTrue(builder.getNodeCount() > 1); + } + + @Test + void testCofactorWithResults() { + builder.setConditionCount(2); + + // Create BDD with result terminals + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + int node = builder.makeNode(0, result0, result1); + + // Cofactor should select appropriate result + assertEquals(result0, builder.cofactor(node, 0, true)); + assertEquals(result1, builder.cofactor(node, 0, false)); + } + + @Test + void testReduceWithResults() { + builder.setConditionCount(2); + + // Create a BDD that uses results properly + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // Create condition nodes that branch to results + int node = builder.makeNode(1, result0, result1); + int root = builder.makeNode(0, node, builder.makeFalse()); + + int reduced = builder.reduce(root); + + // The structure should be preserved + assertNotEquals(0, reduced); // Should not be invalid + assertFalse(builder.isResult(reduced)); // Root should still be a condition node + } + + @Test + void testIteWithResultsInBranches() { + builder.setConditionCount(2); + + // Create a condition and two results + int cond = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // ITE should handle results in then/else branches + int ite = builder.ite(cond, result0, result1); + + // The result should be a node that branches to the two results + assertTrue(ite > 0); + assertFalse(builder.isResult(ite)); + } + + @Test + void testResultMaskNoCollisions() { + builder.setConditionCount(3); + + // Create many nodes to ensure no collision with result encoding + int node1 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int node2 = builder.makeNode(1, node1, builder.makeFalse()); + int node3 = builder.makeNode(2, node2, node1); + + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // Verify no collisions + assertNotEquals(node1, result0); + assertNotEquals(node2, result0); + assertNotEquals(node3, result0); + assertNotEquals(node1, result1); + assertNotEquals(node2, result1); + assertNotEquals(node3, result1); + + // Verify correct identification + assertFalse(builder.isResult(node1)); + assertFalse(builder.isResult(node2)); + assertFalse(builder.isResult(node3)); + assertTrue(builder.isResult(result0)); + assertTrue(builder.isResult(result1)); + } + + @Test + void testReset() { + builder.setConditionCount(2); + + // Create some state + builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + builder.makeResult(0); + + builder.reset(); + + assertEquals(1, builder.getNodeCount()); // Only terminal + assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); + + // Can use builder again + builder.setConditionCount(1); + int newNode = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertNotEquals(0, newNode); // Should get a valid reference + assertNotEquals(1, Math.abs(newNode)); // Should not be a terminal + assertNotEquals(-1, Math.abs(newNode)); // Should not be a terminal + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java new file mode 100644 index 00000000000..40f7fa09f1a --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java @@ -0,0 +1,251 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; + +class BddCompilerTest { + + // Common parameters used across tests + private static final Parameter REGION_PARAM = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + + private static final Parameter BUCKET_PARAM = Parameter.builder() + .name("Bucket") + .type(ParameterType.STRING) + .build(); + + private static final Parameter A_PARAM = Parameter.builder() + .name("A") + .type(ParameterType.STRING) + .build(); + + private static final Parameter B_PARAM = Parameter.builder() + .name("B") + .type(ParameterType.STRING) + .build(); + + @Test + void testCompileSimpleEndpointRule() { + // Single rule with one condition + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(REGION_PARAM).build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + + Bdd bdd = compiler.compile(); + + assertNotNull(bdd); + assertEquals(1, bdd.getConditionCount()); + // Results include: NoMatchRule, endpoint, and possibly a terminal for the false branch + assertEquals(3, bdd.getResultCount()); + assertTrue(bdd.getRootRef() != 0); + } + + @Test + void testCompileErrorRule() { + // Error rule instead of endpoint + Rule rule = ErrorRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .error("Bucket is required"); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(BUCKET_PARAM).build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + + Bdd bdd = compiler.compile(); + + assertEquals(1, bdd.getConditionCount()); + // Results: NoMatchRule, error, and possibly a terminal + assertEquals(3, bdd.getResultCount()); + } + + @Test + void testCompileTreeRule() { + // Nested tree rule + Rule nestedRule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.example.com")); + Rule treeRule = TreeRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .treeRule(nestedRule); + Parameters params = Parameters.builder().addParameter(REGION_PARAM).addParameter(BUCKET_PARAM).build(); + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(treeRule).build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + + Bdd bdd = compiler.compile(); + + assertEquals(2, bdd.getConditionCount()); + assertTrue(bdd.getNodeCount() >= 2); // Should have multiple nodes + } + + @Test + void testCompileWithCustomOrdering() { + // Multiple conditions to test ordering + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + Parameters params = Parameters.builder().addParameter(A_PARAM).addParameter(B_PARAM).build(); + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); + + Cfg cfg = Cfg.from(ruleSet); + + // Get the actual conditions from the CFG after SSA transform + Condition[] cfgConditions = cfg.getConditions(); + + // Find the conditions that correspond to A and B + Condition condA = null; + Condition condB = null; + for (Condition c : cfgConditions) { + String condStr = c.toString(); + if (condStr.contains("isSet(A)")) { + condA = c; + } else if (condStr.contains("isSet(B)")) { + condB = c; + } + } + + assertNotNull(condA, "Could not find condition for A"); + assertNotNull(condB, "Could not find condition for B"); + + // Use fixed ordering (B before A) + OrderingStrategy customOrdering = OrderingStrategy.fixed(Arrays.asList(condB, condA)); + + BddCompiler compiler = new BddCompiler(cfg, customOrdering, new BddBuilder()); + Bdd bdd = compiler.compile(); + + List orderedConditions = compiler.getOrderedConditions(); + + // Verify ordering was applied + assertEquals(2, bdd.getConditionCount()); + assertEquals(2, orderedConditions.size()); + // Verify B comes before A in the ordering + assertEquals(condB, orderedConditions.get(0)); + assertEquals(condA, orderedConditions.get(1)); + } + + @Test + void testCompileEmptyRuleSet() { + // No rules + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(Parameters.builder().build()).build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + assertEquals(0, bdd.getConditionCount()); + // Even with no rules, there's the NoMatchRule and possibly a terminal + assertEquals(2, bdd.getResultCount()); + // Should have at least terminal node + assertEquals(1, bdd.getNodeCount()); + } + + @Test + void testCompileSameResultMultiplePaths() { + // Two rules leading to same endpoint + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(REGION_PARAM) + .addParameter(BUCKET_PARAM) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule1) + .addRule(rule2) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + + Bdd bdd = compiler.compile(); + List results = compiler.getIndexedResults(); + + // Should have 2 conditions + assertEquals(2, bdd.getConditionCount()); + // Results: NoMatchRule at index 0, plus the endpoint(s) + // The compiler may deduplicate identical endpoints or keep them separate + assertTrue(bdd.getResultCount() >= 2); + assertTrue(bdd.getResultCount() <= 3); + + // Verify NoMatchRule is always at index 0 + assertEquals("NoMatchRule", results.get(0).getClass().getSimpleName()); + } + + @Test + void testCompileWithReduction() { + // Test that the BDD is properly reduced after compilation + Rule rule = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(REGION_PARAM) + .addParameter(BUCKET_PARAM) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddBuilder builder = new BddBuilder(); + BddCompiler compiler = new BddCompiler(cfg, builder); + + Bdd bdd = compiler.compile(); + + // The BDD should be reduced (no redundant nodes) + assertNotNull(bdd); + assertEquals(2, bdd.getConditionCount()); + + // After reduction, we should have a minimal BDD + // For 2 conditions with AND semantics leading to one endpoint: + // We expect approximately 3-4 nodes (depending on the exact structure) + assertTrue(bdd.getNodeCount() <= 5, "BDD should be reduced to minimal form"); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java new file mode 100644 index 00000000000..458f696bdf6 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java @@ -0,0 +1,213 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.ListUtils; + +class BddEquivalenceCheckerTest { + + @Test + void testSimpleEquivalentBdd() { + // Create a simple ruleset + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")), + // Default case + ErrorRule.builder().error(Literal.of("No region provided")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testEmptyRulesetEquivalence() { + // Empty ruleset with a default endpoint + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .rules(ListUtils.of(EndpointRule.builder().endpoint(TestHelpers.endpoint("https://default.com")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testMultipleConditionsEquivalence() { + // Ruleset with multiple conditions (AND logic) + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")), + ErrorRule.builder().error(Literal.of("Missing required parameters")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testSetMaxSamples() { + // Create a simpler test with just 3 parameters to avoid ordering issues + Parameters.Builder paramsBuilder = Parameters.builder(); + List rules = new ArrayList<>(); + + // Add parameters with zero-padded names to ensure correct ordering + for (int i = 0; i < 3; i++) { + String paramName = String.format("Param%02d", i); // Param00, Param01, Param02 + paramsBuilder.addParameter(Parameter.builder().name(paramName).type(ParameterType.STRING).build()); + rules.add(EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet(paramName)).build()) + .endpoint(TestHelpers.endpoint("https://example" + i + ".com"))); + } + + // default case + rules.add(ErrorRule.builder().error(Literal.of("No parameters set"))); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(paramsBuilder.build()) + .rules(rules) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + // Set a small max samples to make test fast + checker.setMaxSamples(100); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testSetMaxDuration() { + // Create a complex ruleset + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")), + ErrorRule.builder().error(Literal.of("No region provided")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + // Set a short timeout + checker.setMaxDuration(Duration.ofMillis(100)); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testLargeNumberOfConditions() { + // Test with 25 conditions to ensure it uses sampling rather than exhaustive testing + Parameters.Builder paramsBuilder = Parameters.builder(); + List conditions = new ArrayList<>(); + + for (int i = 0; i < 25; i++) { + String paramName = String.format("Param%02d", i); + paramsBuilder.addParameter(Parameter.builder().name(paramName).type(ParameterType.STRING).build()); + conditions.add(Condition.builder().fn(TestHelpers.isSet(paramName)).build()); + } + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(paramsBuilder.build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions(conditions) + .endpoint(TestHelpers.endpoint("https://example.com")), + ErrorRule.builder().error(Literal.of("Not all parameters set")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + // Set reasonable limits for large condition sets + checker.setMaxSamples(10000); + checker.setMaxDuration(Duration.ofSeconds(5)); + + assertDoesNotThrow(checker::verify); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java new file mode 100644 index 00000000000..a0c2483052d --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java @@ -0,0 +1,372 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; + +class BddTest { + + @Test + void testConstructorValidation() { + // Should reject complemented root (except -1 which is FALSE terminal) + assertThrows(IllegalArgumentException.class, + () -> new Bdd(-2, 0, 0, 1, consumer -> consumer.accept(-1, 1, -1))); + + // Should accept positive root + Bdd bdd = new Bdd(1, 0, 0, 1, consumer -> consumer.accept(-1, 1, -1)); + assertEquals(1, bdd.getRootRef()); + + // Should accept FALSE terminal as root + Bdd bdd2 = new Bdd(-1, 0, 0, 1, consumer -> consumer.accept(-1, 1, -1)); + assertEquals(-1, bdd2.getRootRef()); + } + + @Test + void testBasicAccessors() { + Bdd bdd = new Bdd(2, 2, 1, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 3, -1); // node 1: var 0, high 3, low -1 + consumer.accept(1, 1, -1); // node 2: var 1, high 1, low -1 + }); + + assertEquals(2, bdd.getConditionCount()); + assertEquals(1, bdd.getResultCount()); + assertEquals(3, bdd.getNodeCount()); + assertEquals(2, bdd.getRootRef()); + + // Test node accessors + assertEquals(-1, bdd.getVariable(0)); + assertEquals(1, bdd.getHigh(0)); + assertEquals(-1, bdd.getLow(0)); + + assertEquals(0, bdd.getVariable(1)); + assertEquals(3, bdd.getHigh(1)); + assertEquals(-1, bdd.getLow(1)); + } + + @Test + void testFromRuleSet() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com"))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Bdd bdd = new BddCompiler(cfg, OrderingStrategy.initialOrdering(cfg), new BddBuilder()).compile(); + + assertTrue(bdd.getConditionCount() > 0); + assertTrue(bdd.getResultCount() > 0); + assertTrue(bdd.getNodeCount() > 1); // At least terminal + one node + } + + @Test + void testFromCfg() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .addRule(ErrorRule.builder().error("test error")) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Bdd bdd = new BddCompiler(cfg, OrderingStrategy.initialOrdering(cfg), new BddBuilder()).compile(); + + assertEquals(0, bdd.getConditionCount()); // No conditions + assertTrue(bdd.getResultCount() > 0); + } + + @Test + void testToString() { + Bdd bdd = createSimpleBdd(); + String str = bdd.toString(); + + assertTrue(str.contains("Bdd {")); + assertTrue(str.contains("conditions:")); + assertTrue(str.contains("results:")); + assertTrue(str.contains("root:")); + assertTrue(str.contains("nodes")); + } + + @Test + void testToStringWithDifferentNodeTypes() { + // BDD structure referencing the correct indices + Bdd bdd = new Bdd(2, 2, 3, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 2, -1); // node 1: if Region is set, go to node 2, else FALSE + consumer.accept(1, Bdd.RESULT_OFFSET + 2, Bdd.RESULT_OFFSET + 1); // node 2: if UseFips, return result 2, else result 1 + }); + + String str = bdd.toString(); + + assertTrue(str.contains("conditions: 2")); + assertTrue(str.contains("results: 3")); + assertTrue(str.contains("C0")); + assertTrue(str.contains("C1")); + assertTrue(str.contains("R1")); + assertTrue(str.contains("R2")); + } + + @Test + void testReferenceHelperMethods() { + // Test isNodeReference + assertTrue(Bdd.isNodeReference(2)); + assertTrue(Bdd.isNodeReference(-2)); + assertFalse(Bdd.isNodeReference(0)); + assertFalse(Bdd.isNodeReference(1)); + assertFalse(Bdd.isNodeReference(-1)); + assertFalse(Bdd.isNodeReference(Bdd.RESULT_OFFSET)); + + // Test isResultReference + assertTrue(Bdd.isResultReference(Bdd.RESULT_OFFSET)); + assertTrue(Bdd.isResultReference(Bdd.RESULT_OFFSET + 1)); + assertFalse(Bdd.isResultReference(1)); + assertFalse(Bdd.isResultReference(-1)); + + // Test isTerminal + assertTrue(Bdd.isTerminal(1)); + assertTrue(Bdd.isTerminal(-1)); + assertFalse(Bdd.isTerminal(2)); + assertFalse(Bdd.isTerminal(Bdd.RESULT_OFFSET)); + + // Test isComplemented + assertTrue(Bdd.isComplemented(-2)); + assertTrue(Bdd.isComplemented(-3)); + assertFalse(Bdd.isComplemented(-1)); // FALSE terminal is not considered complemented + assertFalse(Bdd.isComplemented(1)); + assertFalse(Bdd.isComplemented(2)); + } + + private Bdd createSimpleBdd() { + return new Bdd(2, 1, 2, 2, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, Bdd.RESULT_OFFSET + 1, -1); // node 1: if cond true, return result 1, else FALSE + }); + } + + @Test + void testStreamingConstructorValidation() { + // Valid construction + assertDoesNotThrow(() -> { + new Bdd(1, 1, 1, 1, consumer -> { + consumer.accept(-1, 1, -1); + }); + }); + + // Root cannot be complemented (except -1) + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(-2, 1, 1, 1, consumer -> { + consumer.accept(-1, 1, -1); + }); + }); + + // Root -1 (FALSE) is allowed + assertDoesNotThrow(() -> { + new Bdd(-1, 1, 1, 1, consumer -> { + consumer.accept(-1, 1, -1); + }); + }); + + // Wrong node count + assertThrows(IllegalStateException.class, () -> { + new Bdd(1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); // Only provides 1 node, but claims 2 + }); + }); + } + + @Test + void testArrayConstructorValidation() { + int[] nodes = {-1, 1, -1}; + + // Valid construction + assertDoesNotThrow(() -> { + new Bdd(1, 1, 1, 1, nodes); + }); + + // Wrong array length (not multiple of 3) + int[] wrongLength = {-1, 1, -1, 0}; // 4 elements, not divisible by 3 + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(1, 1, 1, 1, wrongLength); + }); + + // Array length doesn't match nodeCount + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(1, 1, 1, 2, nodes); // nodeCount=2 but array has 3 elements (1 node) + }); + + // Root cannot be complemented (except -1) + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(-2, 1, 1, 1, nodes); + }); + + // Root -1 (FALSE) is allowed + assertDoesNotThrow(() -> { + new Bdd(-1, 1, 1, 1, nodes); + }); + } + + @Test + void testGetterBoundsChecking() { + Bdd bdd = new Bdd(1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + + // Valid indices + assertDoesNotThrow(() -> bdd.getVariable(0)); + assertDoesNotThrow(() -> bdd.getVariable(1)); + + // Out of bounds + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getVariable(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getVariable(2)); + + // Same for high/low + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getHigh(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getHigh(2)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getLow(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getLow(2)); + } + + @Test + void testEquals() { + // Create two identical BDDs + Bdd bdd1 = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + + Bdd bdd2 = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + + // Same content should be equal + assertEquals(bdd1, bdd2); + assertEquals(bdd1.hashCode(), bdd2.hashCode()); + + // Self equality + assertEquals(bdd1, bdd1); + + // Different root ref (use TRUE terminal) + Bdd bdd3 = new Bdd(1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + assertNotEquals(bdd1, bdd3); + + // Different root ref (use FALSE terminal) + Bdd bdd4 = new Bdd(-1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + assertNotEquals(bdd1, bdd4); + + // Different node count + Bdd bdd5 = new Bdd(2, 1, 1, 3, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + consumer.accept(0, -1, 1); + }); + assertNotEquals(bdd1, bdd5); + + // Different node content + Bdd bdd6 = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, -1, 1); // Different high/low + }); + assertNotEquals(bdd1, bdd6); + + // Different root ref (use result reference) + Bdd bdd7 = new Bdd(Bdd.RESULT_OFFSET, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + assertNotEquals(bdd1, bdd7); + + // Null and different type + assertNotEquals(bdd1, null); + assertNotEquals(bdd1, "not a BDD"); + } + + // Used to regenerate BDD test cases for errorfiles + // @Test + // void generateValidBddEncoding() { + // Parameter region = Parameter.builder() + // .name("Region") + // .type(ParameterType.STRING) + // .required(true) + // .documentation("The AWS region") + // .build(); + // + // Parameter useFips = Parameter.builder() + // .name("UseFips") + // .type(ParameterType.BOOLEAN) + // .required(true) + // .defaultValue(software.amazon.smithy.rulesengine.language.evaluation.value.Value.booleanValue(false)) + // .documentation("Use FIPS endpoints") + // .build(); + // + // Parameters params = Parameters.builder() + // .addParameter(region) + // .addParameter(useFips) + // .build(); + // + // Condition useFipsTrue = Condition.builder() + // .fn(BooleanEquals.ofExpressions( + // Expression.getReference(Identifier.of("UseFips")), + // Expression.of(true))) + // .build(); + // + // // Create endpoints + // Endpoint normalEndpoint = Endpoint.builder() + // .url(Expression.of("https://service.{Region}.amazonaws.com")) + // .build(); + // + // Endpoint fipsEndpoint = Endpoint.builder() + // .url(Expression.of("https://service-fips.{Region}.amazonaws.com")) + // .build(); + // + // Rule fipsRule = EndpointRule.builder() + // .condition(useFipsTrue) + // .endpoint(fipsEndpoint); + // + // Rule normalRule = EndpointRule.builder() + // .endpoint(normalEndpoint); + // + // EndpointRuleSet ruleSet = EndpointRuleSet.builder() + // .parameters(params) + // .rules(Arrays.asList(fipsRule, normalRule)) + // .build(); + // + // Cfg cfg = Cfg.from(ruleSet); + // + // BddTrait trait = BddTrait.from(cfg); + // BddTraitValidator validator = new BddTraitValidator(); + // ServiceShape service = ServiceShape.builder().id("foo#Bar").addTrait(trait).build(); + // Model model = Model.builder().addShape(service).build(); + // System.out.println(validator.validate(model)); + // + // // Get the base64 encoded nodes + // System.out.println(Node.prettyPrintJson(trait.toNode())); + // } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysisTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysisTest.java new file mode 100644 index 00000000000..df1c9a6ec80 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysisTest.java @@ -0,0 +1,233 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.ListUtils; + +class CfgConeAnalysisTest { + + @Test + void testSingleConditionSingleResult() { + // Simple rule: if Region is set, return endpoint + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // The single condition should have dominator depth 0 (at root) + assertEquals(0, analysis.dominatorDepth(0)); + // Should reach 2 result nodes (the endpoint and the terminal/no-match) + assertEquals(2, analysis.coneSize(0)); + } + + @Test + void testChainedConditions() { + // Rule with two conditions in sequence (AND logic) + Rule rule = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // First condition should be at depth 0 + assertEquals(0, analysis.dominatorDepth(0)); + // Second condition should be at depth 1 (one edge from root) + assertEquals(1, analysis.dominatorDepth(1)); + + // Both conditions lead to the same result + assertTrue(analysis.coneSize(0) >= 1); + assertTrue(analysis.coneSize(1) >= 1); + } + + @Test + void testMultipleBranches() { + // Two separate rules leading to different endpoints + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://regional.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // Both conditions should be at the root level (depth 0 or 1) + assertTrue(analysis.dominatorDepth(0) <= 1); + assertTrue(analysis.dominatorDepth(1) <= 1); + + // Each condition reaches at least one result + assertTrue(analysis.coneSize(0) >= 1); + assertTrue(analysis.coneSize(1) >= 1); + } + + @Test + void testNestedTreeRule() { + // Tree rule with nested structure + Rule innerRule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.com")); + + Rule treeRule = TreeRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // Region condition should be at root (depth 0) + int regionIdx = -1; + int bucketIdx = -1; + for (int i = 0; i < conditions.length; i++) { + if (conditions[i].toString().contains("Region")) { + regionIdx = i; + } else if (conditions[i].toString().contains("Bucket")) { + bucketIdx = i; + } + } + + if (regionIdx >= 0) { + assertEquals(0, analysis.dominatorDepth(regionIdx)); + } + + // Bucket condition should be deeper (at least depth 1) + if (bucketIdx >= 0) { + assertTrue(analysis.dominatorDepth(bucketIdx) >= 1); + } + } + + @Test + void testErrorRule() { + // Rule that returns an error instead of endpoint + Rule rule = ErrorRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("InvalidParam")).build()) + .error("Invalid parameter provided"); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("InvalidParam").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // The condition should be at root + assertEquals(0, analysis.dominatorDepth(0)); + // Should reach 2 results (the error and terminal/no-match) + assertEquals(2, analysis.coneSize(0)); + } + + @Test + void testEmptyCfg() { + // Empty ruleset + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + + // Should have no conditions + assertEquals(0, conditions.length); + + // Analysis should handle empty CFG gracefully + Map conditionToIndex = new HashMap<>(); + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // No assertions needed - just verify it doesn't throw + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrderingTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrderingTest.java new file mode 100644 index 00000000000..bddb4a406ea --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrderingTest.java @@ -0,0 +1,335 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.ListUtils; + +class InitialOrderingTest { + + @Test + void testSimpleOrdering() { + // Single rule with one condition + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + InitialOrdering ordering = new InitialOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + assertNotNull(ordered); + assertEquals(1, ordered.size()); + assertTrue(ordered.get(0).toString().contains("Region")); + } + + @Test + void testDependencyOrdering() { + // Rule with variable dependencies: x = isSet(A), then use x + Condition defineX = Condition.builder() + .fn(TestHelpers.isSet("A")) + .result(Identifier.of("x")) + .build(); + + Condition useX = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("x")), + Literal.of(true))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(defineX, useX) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + // After SSA and coalesce transforms, there might be only one condition + Condition[] conditions = cfg.getConditions(); + + // If coalesce merged them, we'll have 1 condition. Otherwise 2. + assertTrue(conditions.length >= 1 && conditions.length <= 2); + + InitialOrdering ordering = new InitialOrdering(cfg); + List ordered = ordering.orderConditions(conditions); + + assertEquals(conditions.length, ordered.size()); + + // If we still have 2 conditions, verify dependency order + if (ordered.size() == 2) { + // Find which condition defines x and which uses it + int defineIndex = -1; + int useIndex = -1; + for (int i = 0; i < ordered.size(); i++) { + if (ordered.get(i).getResult().isPresent() && + ordered.get(i).getResult().get().toString().equals("x")) { + defineIndex = i; + } else if (ordered.get(i).toString().contains("x")) { + useIndex = i; + } + } + + // Define must come before use (if both exist) + if (defineIndex >= 0 && useIndex >= 0) { + assertTrue(defineIndex < useIndex, "Definition of x must come before its use"); + } + } + } + + @Test + void testGateConditionPriority() { + // Create a gate condition (isSet) that multiple other conditions depend on + Condition gate = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + // Multiple conditions that use hasRegion + Condition branch1 = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Condition branch2 = Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .build(); + + Rule rule1 = EndpointRule.builder() + .conditions(gate, branch1) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(gate, branch2) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + InitialOrdering ordering = new InitialOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + // Gate condition should be ordered early since multiple branches depend on it + assertNotNull(ordered); + assertTrue(ordered.size() >= 2); + + // Find the gate condition + int gateIndex = -1; + for (int i = 0; i < ordered.size(); i++) { + if (ordered.get(i).getResult().isPresent() && + ordered.get(i).getResult().get().toString().equals("hasRegion")) { + gateIndex = i; + break; + } + } + + // Gate should be ordered before conditions that depend on it + assertTrue(gateIndex >= 0, "Gate condition should be in the ordering"); + assertTrue(gateIndex < ordered.size() - 1, "Gate should not be last"); + } + + @Test + void testNestedTreeOrdering() { + // Nested tree structure to test depth-based ordering + Rule innerRule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.com")); + + Rule treeRule = TreeRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + InitialOrdering ordering = new InitialOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + assertEquals(2, ordered.size()); + + // Region should come before Bucket due to tree structure + int regionIndex = -1; + int bucketIndex = -1; + for (int i = 0; i < ordered.size(); i++) { + if (ordered.get(i).toString().contains("Region")) { + regionIndex = i; + } else if (ordered.get(i).toString().contains("Bucket")) { + bucketIndex = i; + } + } + + assertTrue(regionIndex < bucketIndex, "Region should be ordered before Bucket"); + } + + @Test + void testMultipleIndependentConditions() { + // Multiple conditions with no dependencies + Rule rule = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build(), + Condition.builder().fn(TestHelpers.isSet("C")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("B").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("C").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + InitialOrdering ordering = new InitialOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + // Should order all conditions + assertEquals(3, ordered.size()); + + // Order should be deterministic (based on CFG structure). Run twice to ensure consistency. + List ordered2 = ordering.orderConditions(cfg.getConditions()); + assertEquals(ordered, ordered2, "Ordering should be deterministic"); + } + + @Test + void testEmptyConditions() { + // Ruleset with no conditions + Rule rule = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://default.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + InitialOrdering ordering = new InitialOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + assertNotNull(ordered); + assertEquals(0, ordered.size()); + } + + @Test + void testIsSetGatePriority() { + // Test that isSet conditions used by multiple consumers get priority + Condition isSetGate = Condition.builder() + .fn(IsSet.ofExpressions(Expression.getReference(Identifier.of("Input")))) + .result(Identifier.of("hasInput")) + .build(); + + // Multiple rules use the hasInput variable + Rule rule1 = EndpointRule.builder() + .conditions( + isSetGate, + Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasInput")), + Literal.of(true))) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions( + isSetGate, + Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasInput")), + Literal.of(false))) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Input").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + InitialOrdering ordering = new InitialOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + // The isSet gate should be ordered early + assertNotNull(ordered); + + // After transforms, the structure might change + // Just verify that if an isSet condition exists, it's prioritized + boolean hasIsSet = false; + for (Condition c : ordered) { + if (c.getFunction().getFunctionDefinition() == IsSet.getDefinition()) { + hasIsSet = true; + break; + } + } + + // The test's purpose is to verify prioritization works, not specific positions + // After coalesce transform, the isSet might be merged into other conditions + assertFalse(ordered.isEmpty(), "Should have at least one condition"); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java new file mode 100644 index 00000000000..835d45b3363 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java @@ -0,0 +1,204 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.util.ArrayList; +import java.util.Collections; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; + +class NodeReversalTest { + + @Test + void testSingleNodeBdd() { + // BDD with just terminal node + Bdd original = new Bdd(1, 0, 0, 1, consumer -> { + consumer.accept(-1, 1, -1); // terminal + }); + + Bdd reversed = NodeReversal.reverse(original); + + // Should be unchanged (only 1 node, reversal returns as-is for <= 2 nodes). + assertEquals(1, reversed.getNodeCount()); + assertEquals(1, reversed.getRootRef()); + + // Check terminal node + assertEquals(-1, reversed.getVariable(0)); + assertEquals(1, reversed.getHigh(0)); + assertEquals(-1, reversed.getLow(0)); + } + + @Test + void testComplementEdges() { + // BDD with complement edges + Bdd original = new Bdd(2, 2, 0, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 3, -2); // node 1: condition 0, high to node 2, low to complement of node 1 + consumer.accept(1, 1, -1); // node 2: condition 1 + }); + + Bdd reversed = NodeReversal.reverse(original); + + // Mapping: 0->0, 1->2, 2->1 + // Ref mapping: 2->3, 3->2, -2->-3 + assertEquals(3, reversed.getRootRef()); + + // Check complement edge is properly remapped + // Original node 1 is now at index 2 + assertEquals(0, reversed.getVariable(2)); // condition index unchanged + assertEquals(2, reversed.getHigh(2)); // high ref 3 -> 2 + assertEquals(-3, reversed.getLow(2)); // complement low ref -2 -> -3 + } + + @Test + void testResultNodes() { + // BDD with result terminals + Bdd original = new Bdd(2, 4, 2, 4, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, Bdd.RESULT_OFFSET + 1, Bdd.RESULT_OFFSET); // node 1: condition 0 + consumer.accept(2, 1, -1); // node 2: result 0 + consumer.accept(3, 1, -1); // node 3: result 1 + }); + + Bdd reversed = NodeReversal.reverse(original); + + assertEquals(4, reversed.getNodeCount()); + assertEquals(4, reversed.getRootRef()); // root was ref 2, now ref 4 + + // Terminal stays at 0 + assertEquals(-1, reversed.getVariable(0)); + assertEquals(1, reversed.getHigh(0)); + assertEquals(-1, reversed.getLow(0)); + + // Original node 3 now at index 1 + assertEquals(3, reversed.getVariable(1)); + assertEquals(1, reversed.getHigh(1)); + assertEquals(-1, reversed.getLow(1)); + + // Original node 2 stays at index 2 + assertEquals(2, reversed.getVariable(2)); + assertEquals(1, reversed.getHigh(2)); + assertEquals(-1, reversed.getLow(2)); + + // Original node 1 now at index 3 + assertEquals(0, reversed.getVariable(3)); // condition index unchanged + assertEquals(Bdd.RESULT_OFFSET + 1, reversed.getHigh(3)); // result references unchanged + assertEquals(Bdd.RESULT_OFFSET, reversed.getLow(3)); // result references unchanged + } + + @Test + void testFourNodeExample() { + // Simple 4-node example to verify reference mapping + Bdd original = new Bdd(2, 3, 0, 4, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 3, 4); // node 1: points to nodes 2 and 3 + consumer.accept(1, 1, -1); // node 2: + consumer.accept(2, 1, -1); // node 3: + }); + + Bdd reversed = NodeReversal.reverse(original); + + // Mapping: 0->0, 1->3, 2->2, 3->1 + // Ref mapping: 2->4, 3->3, 4->2 + assertEquals(4, reversed.getRootRef()); // root ref 2 -> 4 + + // Check node at index 3 (originally node 1) + assertEquals(0, reversed.getVariable(3)); // original node 1's variable + assertEquals(3, reversed.getHigh(3)); // ref 3 stays 3 + assertEquals(2, reversed.getLow(3)); // ref 4 -> 2 + } + + @Test + void testImmutability() { + // Ensure original BDD is not modified + Bdd original = new Bdd(2, 2, 0, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0 + consumer.accept(0, 3, -1); // node 1 + consumer.accept(1, 1, -1); // node 2 + }); + + // Get original values for comparison + int originalNodeCount = original.getNodeCount(); + int originalRootRef = original.getRootRef(); + + // Store original node values + int[] originalNodeValues = new int[original.getNodeCount() * 3]; + for (int i = 0; i < original.getNodeCount(); i++) { + originalNodeValues[i * 3] = original.getVariable(i); + originalNodeValues[i * 3 + 1] = original.getHigh(i); + originalNodeValues[i * 3 + 2] = original.getLow(i); + } + + Bdd reversed = NodeReversal.reverse(original); + + // Verify original is unchanged + assertEquals(originalNodeCount, original.getNodeCount()); + assertEquals(originalRootRef, original.getRootRef()); + + // Check original node values haven't changed + for (int i = 0; i < original.getNodeCount(); i++) { + assertEquals(originalNodeValues[i * 3], original.getVariable(i)); + assertEquals(originalNodeValues[i * 3 + 1], original.getHigh(i)); + assertEquals(originalNodeValues[i * 3 + 2], original.getLow(i)); + } + + // Ensure reversed is a different object + assertNotSame(original, reversed); + } + + @Test + void testTwoNodeBdd() { + // Test edge case with exactly 2 nodes + Bdd original = new Bdd(2, 1, 0, 2, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 1, -1); // node 1: simple condition + }); + + Bdd reversed = NodeReversal.reverse(original); + + // Should be unchanged (reversal returns as-is for <= 2 nodes) + assertEquals(2, reversed.getNodeCount()); + assertEquals(2, reversed.getRootRef()); + + // Check nodes are unchanged + assertEquals(-1, reversed.getVariable(0)); + assertEquals(1, reversed.getHigh(0)); + assertEquals(-1, reversed.getLow(0)); + + assertEquals(0, reversed.getVariable(1)); + assertEquals(1, reversed.getHigh(1)); + assertEquals(-1, reversed.getLow(1)); + } + + @Test + void testBddTraitReversalReturnsOriginalForSmallBdd() { + // Test that small BDDs return the original trait unchanged + NodeReversal reversal = new NodeReversal(); + + // Create a BddTrait with a 2-node BDD + Bdd bdd = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 1, -1); // node 1: simple condition + }); + + EndpointBddTrait originalTrait = EndpointBddTrait.builder() + .parameters(Parameters.builder().build()) + .conditions(new ArrayList<>()) + .results(Collections.singletonList(NoMatchRule.INSTANCE)) + .bdd(bdd) + .build(); + + EndpointBddTrait reversedTrait = reversal.apply(originalTrait); + + // Should return the exact same trait object for small BDDs + assertSame(originalTrait, reversedTrait); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java new file mode 100644 index 00000000000..e5eb794103d --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java @@ -0,0 +1,232 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; + +// Does some basic checks, but doesn't get too specific so we can easily change the sifting algorithm. +class SiftingOptimizationTest { + + @Test + void testBasicOptimization() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("B").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("C").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build(), + Condition.builder().fn(TestHelpers.isSet("C")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); + + // Basic checks + assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); + assertEquals(originalTrait.getResults().size(), optimizedTrait.getResults().size()); + assertEquals(originalTrait.getBdd().getConditionCount(), optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getBdd().getResultCount(), optimizedTrait.getBdd().getResultCount()); + + // Size should be same or smaller + assertTrue(optimizedTrait.getBdd().getNodeCount() <= originalTrait.getBdd().getNodeCount()); + } + + @Test + void testDependenciesPreserved() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Input").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Input")) + .result(Identifier.of("hasInput")) + .build(), + Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasInput")), + Literal.of(true))) + .build(), + Condition.builder() + .fn(TestHelpers.isSet("Region")) + .build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); + + // Verify the optimizer preserved the number of conditions + assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); + assertEquals(originalTrait.getBdd().getConditionCount(), optimizedTrait.getBdd().getConditionCount()); + + // The fact that the optimization completes successfully and produces a valid BDD + // means dependencies were preserved (otherwise the BddCompiler would have failed + // during the optimization process). + + // Also verify results are preserved + assertEquals(originalTrait.getResults(), optimizedTrait.getResults()); + } + + @Test + void testSingleCondition() { + // Test a single condition edge case + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); + Cfg cfg = Cfg.from(ruleSet); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); + + // Should be unchanged or very similar + assertEquals(originalTrait.getBdd().getNodeCount(), optimizedTrait.getBdd().getNodeCount()); + assertEquals(1, optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getConditions(), optimizedTrait.getConditions()); + } + + @Test + void testEmptyRuleSet() { + // Test empty ruleset edge case + Parameters params = Parameters.builder().build(); + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).build(); + + Cfg cfg = Cfg.from(ruleSet); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); + + assertEquals(0, optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getBdd().getResultCount(), optimizedTrait.getBdd().getResultCount()); + assertEquals(originalTrait.getResults(), optimizedTrait.getResults()); + } + + @Test + void testLargeReduction() { + // Create a ruleset that should benefit from optimization + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("B").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("C").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("D").type(ParameterType.STRING).build()) + .build(); + + // Multiple rules with overlapping conditions + Rule rule1 = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build()) + .endpoint(TestHelpers.endpoint("https://ab.example.com")); + + Rule rule2 = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("C")).build()) + .endpoint(TestHelpers.endpoint("https://ac.example.com")); + + Rule rule3 = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("B")).build(), + Condition.builder().fn(TestHelpers.isSet("D")).build()) + .endpoint(TestHelpers.endpoint("https://bd.example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule1) + .addRule(rule2) + .addRule(rule3) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder() + .cfg(cfg) + .granularEffort(100_000, 10) // Allow more aggressive optimization + .build(); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); + + // Should maintain correctness + assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); + assertEquals(originalTrait.getBdd().getConditionCount(), optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getBdd().getResultCount(), optimizedTrait.getBdd().getResultCount()); + assertEquals(originalTrait.getResults(), optimizedTrait.getResults()); + + // Often achieves some reduction + assertTrue(optimizedTrait.getBdd().getNodeCount() <= originalTrait.getBdd().getNodeCount()); + } + + @Test + void testNoImprovementReturnsOriginal() { + // Test that when no improvement is found, the original trait is returned + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); + Cfg cfg = Cfg.from(ruleSet); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder() + .cfg(cfg) + .coarseEffort(1, 1) // Minimal effort to likely find no improvement + .build(); + + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); + + // For simple cases with minimal optimization effort, should return the same trait object + if (optimizedTrait.getBdd().getNodeCount() == originalTrait.getBdd().getNodeCount()) { + assertTrue(optimizedTrait == originalTrait); + } + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java new file mode 100644 index 00000000000..8fb61f6b89a --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -0,0 +1,349 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionReference; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; + +class CfgBuilderTest { + + private CfgBuilder builder; + private EndpointRuleSet ruleSet; + + @BeforeEach + void setUp() { + Parameter region = Parameter.builder().name("region").type(ParameterType.STRING).build(); + Parameter useFips = Parameter.builder() + .name("useFips") + .type(ParameterType.BOOLEAN) + .defaultValue(Value.booleanValue(false)) + .required(true) + .build(); + ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(region).addParameter(useFips).build()) + .build(); + + builder = new CfgBuilder(ruleSet); + } + + @Test + void buildRequiresNonNullRoot() { + assertThrows(NullPointerException.class, () -> builder.build(null)); + } + + @Test + void buildCreatesValidCfg() { + CfgNode root = ResultNode.terminal(); + Cfg cfg = builder.build(root); + + assertNotNull(cfg); + assertSame(root, cfg.getRoot()); + assertEquals(ruleSet.getParameters(), cfg.getParameters()); + } + + @Test + void createResultNodesCachesIdenticalRules() { + Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + Rule rule2 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + CfgNode node1 = builder.createResult(rule1); + CfgNode node2 = builder.createResult(rule2); + + assertSame(node1, node2); + } + + @Test + void createResultNodesDistinguishesDifferentRules() { + Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example1.com")); + Rule rule2 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example2.com")); + + CfgNode node1 = builder.createResult(rule1); + CfgNode node2 = builder.createResult(rule2); + + assertNotSame(node1, node2); + } + + @Test + void createResultStripsConditionsBeforeCaching() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Rule ruleWithCondition = EndpointRule.builder() + .condition(cond) + .endpoint(TestHelpers.endpoint("https://example.com")); + Rule ruleWithoutCondition = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + CfgNode node1 = builder.createResult(ruleWithCondition); + CfgNode node2 = builder.createResult(ruleWithoutCondition); + + assertSame(node1, node2); + } + + @Test + void createConditionCachesIdenticalNodes() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + CfgNode trueBranch = ResultNode.terminal(); + CfgNode falseBranch = ResultNode.terminal(); + + CfgNode node1 = builder.createCondition(cond, trueBranch, falseBranch); + CfgNode node2 = builder.createCondition(cond, trueBranch, falseBranch); + + assertSame(node1, node2); + } + + @Test + void createConditionDistinguishesDifferentBranches() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + CfgNode trueBranch1 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://true1.com"))); + CfgNode trueBranch2 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://true2.com"))); + CfgNode falseBranch = ResultNode.terminal(); + + CfgNode node1 = builder.createCondition(cond, trueBranch1, falseBranch); + CfgNode node2 = builder.createCondition(cond, trueBranch2, falseBranch); + + assertNotSame(node1, node2); + } + + @Test + void createConditionReferenceHandlesSimpleCondition() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + + ConditionReference ref = builder.createConditionReference(cond); + + assertNotNull(ref); + assertEquals(cond, ref.getCondition()); + assertFalse(ref.isNegated()); + } + + @Test + void createConditionReferenceCachesIdenticalConditions() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("region")).build(); + + ConditionReference ref1 = builder.createConditionReference(cond1); + ConditionReference ref2 = builder.createConditionReference(cond2); + + assertSame(ref1, ref2); + } + + @Test + void createConditionReferenceHandlesNegation() { + Condition innerCond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition negatedCond = Condition.builder().fn(Not.ofExpressions(innerCond.getFunction())).build(); + ConditionReference ref = builder.createConditionReference(negatedCond); + + assertNotNull(ref); + assertTrue(ref.isNegated()); + assertEquals(innerCond.getFunction(), ref.getCondition().getFunction()); + } + + @Test + void createConditionReferenceSharesInfoForNegatedAndNonNegated() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition negatedCond = Condition.builder().fn(Not.ofExpressions(cond.getFunction())).build(); + + ConditionReference ref1 = builder.createConditionReference(cond); + ConditionReference ref2 = builder.createConditionReference(negatedCond); + + assertEquals(ref1.getCondition(), ref2.getCondition()); + assertFalse(ref1.isNegated()); + assertTrue(ref2.isNegated()); + } + + @Test + void createConditionReferenceHandlesBooleanEqualsCanonicalizations() { + // Test booleanEquals(useFips, false) -> booleanEquals(useFips, true) with negation + Expression ref = Expression.getReference(Identifier.of("useFips")); + Condition cond = Condition.builder().fn(BooleanEquals.ofExpressions(ref, false)).build(); + + ConditionReference condRef = builder.createConditionReference(cond); + + // Should be canonicalized to booleanEquals(useFips, true) with negation + assertTrue(condRef.isNegated()); + assertInstanceOf(BooleanEquals.class, condRef.getCondition().getFunction()); + + BooleanEquals fn = (BooleanEquals) condRef.getCondition().getFunction(); + assertEquals(ref, fn.getArguments().get(0)); + assertEquals(Literal.booleanLiteral(true), fn.getArguments().get(1)); + } + + @Test + void createConditionReferenceDoesNotCanonicalizeWithoutDefault() { + // Test that booleanEquals(region, false) is not canonicalized (no default) + Expression ref = Expression.getReference(Identifier.of("region")); + Condition cond = Condition.builder().fn(BooleanEquals.ofExpressions(ref, false)).build(); + + ConditionReference condRef = builder.createConditionReference(cond); + + assertFalse(condRef.isNegated()); + assertEquals(cond.getFunction(), condRef.getCondition().getFunction()); + } + + @Test + void createConditionReferenceHandlesCommutativeCanonicalizations() { + Expression ref = Expression.getReference(Identifier.of("region")); + + // Create conditions with different argument orders + Condition cond1 = Condition.builder().fn(StringEquals.ofExpressions(ref, "us-east-1")).build(); + Condition cond2 = Condition.builder().fn(StringEquals.ofExpressions(Expression.of("us-east-1"), ref)).build(); + + ConditionReference ref1 = builder.createConditionReference(cond1); + ConditionReference ref2 = builder.createConditionReference(cond2); + + // Both should produce equivalent canonicalized references. + // They should have the same underlying condition after canonicalization + assertEquals(ref1.getCondition(), ref2.getCondition()); + assertEquals(ref1.isNegated(), ref2.isNegated()); + } + + @Test + void createConditionReferenceHandlesVariableBinding() { + Condition cond = Condition.builder() + .fn(TestHelpers.parseUrl("{url}")) + .result(Identifier.of("parsedUrl")) + .build(); + + ConditionReference ref = builder.createConditionReference(cond); + + assertNotNull(ref); + assertEquals(cond, ref.getCondition()); + } + + @Test + void createConditionHandlesComplexNesting() { + // Build a nested structure to test caching + CfgNode endpoint1 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://endpoint1.com"))); + CfgNode endpoint2 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://endpoint2.com"))); + CfgNode errorNode = builder.createResult(ErrorRule.builder().error("Invalid configuration")); + + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.stringEquals("region", "us-east-1")).build(); + + // Create nested conditions + CfgNode inner = builder.createCondition(cond2, endpoint1, endpoint2); + CfgNode outer = builder.createCondition(cond1, inner, errorNode); + + assertInstanceOf(ConditionNode.class, outer); + ConditionNode outerNode = (ConditionNode) outer; + assertEquals(cond1, outerNode.getCondition().getCondition()); + assertSame(inner, outerNode.getTrueBranch()); + assertSame(errorNode, outerNode.getFalseBranch()); + } + + @Test + void createConditionReferenceIgnoresNegationWithVariableBinding() { + // Negation with variable binding should not unwrap + Condition innerCond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + + Condition negatedWithBinding = Condition.builder() + .fn(Not.ofExpressions(innerCond.getFunction())) + .result(Identifier.of("notRegionSet")) + .build(); + + ConditionReference ref = builder.createConditionReference(negatedWithBinding); + + // Should not be treated as simple negation due to variable binding + assertFalse(ref.isNegated()); + assertInstanceOf(Not.class, ref.getCondition().getFunction()); + assertEquals(negatedWithBinding, ref.getCondition()); + } + + @Test + void createResultPreservesHeadersAndPropertiesInSignature() { + // Create endpoints with same URL but different headers + Map> headers1 = new HashMap<>(); + headers1.put("x-custom", Collections.singletonList(Expression.of("value1"))); + + Map> headers2 = new HashMap<>(); + headers2.put("x-custom", Collections.singletonList(Expression.of("value2"))); + + Rule rule1 = EndpointRule.builder() + .endpoint(Endpoint.builder() + .url(Expression.of("https://example.com")) + .headers(headers1) + .build()); + Rule rule2 = EndpointRule.builder() + .endpoint(Endpoint.builder() + .url(Expression.of("https://example.com")) + .headers(headers2) + .build()); + + EndpointRuleSet ruleSetWithHeaders = EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(ListUtils.of(rule1, rule2)) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithHeaders); + + CfgNode node1 = convergenceBuilder.createResult(rule1); + CfgNode node2 = convergenceBuilder.createResult(rule2); + + // Different headers mean different signatures - no convergence + assertNotSame(node1, node2); + } + + @Test + void createResultDistinguishesEndpointsWithDifferentStructure() { + // Create rules with different endpoint structures + Rule rule1 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://{region}.example.com")); + Rule rule2 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://example.com/{region}")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("region") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .build(); + + EndpointRuleSet ruleSetWithEndpoints = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithEndpoints); + + CfgNode node1 = convergenceBuilder.createResult(rule1); + CfgNode node2 = convergenceBuilder.createResult(rule2); + + // Different structures should not converge + assertNotSame(node1, node2); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java new file mode 100644 index 00000000000..2aad5ab9484 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java @@ -0,0 +1,242 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Set; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.ConditionReference; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class CfgTest { + + @Test + void gettersReturnConstructorValues() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .build(); + CfgNode root = ResultNode.terminal(); + + Cfg cfg = new Cfg(ruleSet, root); + + assertSame(ruleSet.getParameters(), cfg.getParameters()); + assertSame(root, cfg.getRoot()); + } + + @Test + void fromCreatesSimpleCfg() { + EndpointRule rule = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + assertNotNull(cfg); + assertNotNull(cfg.getRoot()); + + // Root should be a result node for a simple endpoint rule + assertInstanceOf(ResultNode.class, cfg.getRoot()); + ResultNode resultNode = (ResultNode) cfg.getRoot(); + assertEquals(rule.withConditions(Collections.emptyList()), resultNode.getResult()); + } + + @Test + void fromCreatesConditionNode() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) + .build(); + + EndpointRule rule = EndpointRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + // Root should be a condition node + assertInstanceOf(ConditionNode.class, cfg.getRoot()); + ConditionNode condNode = (ConditionNode) cfg.getRoot(); + assertEquals("isSet(region)", condNode.getCondition().getCondition().toString()); + } + + @Test + void fromHandlesMultipleRules() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) + .build(); + + // TreeRule with isSet check followed by stringEquals + Rule treeRule = TreeRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) + .treeRule( + EndpointRule.builder() + .condition( + Condition.builder().fn(TestHelpers.stringEquals("region", "us-east-1")).build()) + .endpoint(TestHelpers.endpoint("https://us-east-1.example.com")), + EndpointRule.builder() + .condition( + Condition.builder().fn(TestHelpers.stringEquals("region", "eu-west-1")).build()) + .endpoint(TestHelpers.endpoint("https://eu-west-1.example.com"))); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .addRule(ErrorRule.builder().error("Unknown region")) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + assertInstanceOf(ConditionNode.class, cfg.getRoot()); + } + + @Test + void iteratorVisitsAllNodes() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) + .build(); + + Rule rule1 = EndpointRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) + .endpoint(TestHelpers.endpoint("https://with-region.com")); + + Rule rule2 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://no-region.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule1) + .addRule(rule2) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + Set visited = new HashSet<>(); + for (CfgNode node : cfg) { + visited.add(node); + } + + // Should have at least 3 nodes: condition node and 2 result nodes + assertTrue(visited.size() >= 3); + } + + @Test + void iteratorHandlesEmptyCfg() { + CfgNode root = ResultNode.terminal(); + Cfg cfg = new Cfg((EndpointRuleSet) null, root); + + List nodes = new ArrayList<>(); + for (CfgNode node : cfg) { + nodes.add(node); + } + + assertEquals(1, nodes.size()); + assertSame(root, nodes.get(0)); + } + + @Test + void iteratorThrowsNoSuchElementException() { + CfgNode root = ResultNode.terminal(); + Cfg cfg = new Cfg((EndpointRuleSet) null, root); + + Iterator iterator = cfg.iterator(); + assertTrue(iterator.hasNext()); + iterator.next(); + assertFalse(iterator.hasNext()); + + assertThrows(NoSuchElementException.class, iterator::next); + } + + @Test + void iteratorDoesNotVisitNodesTwice() { + // Create a diamond-shaped CFG where multiple paths lead to the same node + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("a").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("b").type(ParameterType.STRING).build()) + .build(); + + CfgBuilder builder = new CfgBuilder(EndpointRuleSet.builder() + .parameters(params) + .build()); + + CfgNode sharedResult = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://shared.com"))); + + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("a")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("b")).build(); + + ConditionReference ref1 = builder.createConditionReference(cond1); + ConditionReference ref2 = builder.createConditionReference(cond2); + + // Both conditions can lead to the same result + CfgNode branch1 = builder.createCondition(ref2, sharedResult, sharedResult); + CfgNode root = builder.createCondition(ref1, branch1, sharedResult); + + Cfg cfg = builder.build(root); + + List visitedNodes = new ArrayList<>(); + for (CfgNode node : cfg) { + visitedNodes.add(node); + } + + // Count occurrences of sharedResult + long sharedResultCount = visitedNodes.stream() + .filter(node -> node == sharedResult) + .count(); + + assertEquals(1, sharedResultCount, "Shared node should only be visited once"); + } + + @Test + void equalsAndHashCodeBasedOnRoot() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .build(); + CfgNode root1 = ResultNode.terminal(); + CfgNode root2 = new ResultNode( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com"))); + + Cfg cfg1a = new Cfg(ruleSet, root1); + Cfg cfg1b = new Cfg(ruleSet, root1); + Cfg cfg2 = new Cfg(ruleSet, root2); + + // Same root + assertEquals(cfg1a, cfg1b); + assertEquals(cfg1a.hashCode(), cfg1b.hashCode()); + + // Different root + assertNotEquals(cfg1a, cfg2); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransformTest.java new file mode 100644 index 00000000000..d2fbac52245 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransformTest.java @@ -0,0 +1,575 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; + +class CoalesceTransformTest { + + @Test + void testActualCoalescing() { + // Test with substring which returns a string (has zero value "") + // This should actually coalesce + Condition checkInput = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .result(Identifier.of("prefix")) + .build(); + + Condition use = Condition.builder() + .fn(StringEquals.ofExpressions( + Expression.getReference(Identifier.of("prefix")), + Literal.of("https"))) + .result(Identifier.of("isHttps")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkInput, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce because substring returns string which has zero value + Rule transformedRule = transformed.getRules().get(0); + assertEquals(2, transformedRule.getConditions().size()); // isSet + coalesced + + Condition coalesced = transformedRule.getConditions().get(1); + assertTrue(coalesced.getResult().isPresent()); + assertEquals("isHttps", coalesced.getResult().get().toString()); + } + + @Test + void testSimpleBindThenUsePattern() { + // parseUrl returns a record type which doesn't have a zero value + // So it won't be coalesced + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + Rule transformedRule = transformed.getRules().get(0); + List conditions = transformedRule.getConditions(); + + // Should not coalesce because parseUrl returns a record without zero value + assertEquals(3, conditions.size()); + } + + @Test + void testDoesNotCoalesceWhenVariableUsedMultipleTimes() { + // Variable is used multiple times - should not coalesce + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use1 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Condition use2 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "authority")) + .result(Identifier.of("authority")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint, bind, use1, use2) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should not coalesce because 'url' is used twice + Rule transformedRule = transformed.getRules().get(0); + assertEquals(4, transformedRule.getConditions().size()); + } + + @Test + void testDoesNotCoalesceIsSetFunction() { + // isSet functions should not be coalesced + Condition bind = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .result(Identifier.of("regionIsSet")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce because BooleanType has a zero value (false) + // The actual behavior is that it DOES coalesce boolean operations + Rule transformedRule = transformed.getRules().get(0); + assertEquals(1, transformedRule.getConditions().size()); + } + + @Test + void testMultipleCoalescesInSameRule() { + // parseUrl returns a record type which doesn't have a zero value + // So these won't be coalesced + Condition checkEndpoint1 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint1")) + .build(); + + Condition checkEndpoint2 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint2")) + .build(); + + Condition bind1 = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint1")) + .result(Identifier.of("url1")) + .build(); + + Condition use1 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url1")), + "scheme")) + .result(Identifier.of("scheme1")) + .build(); + + Condition bind2 = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint2")) + .result(Identifier.of("url2")) + .build(); + + Condition use2 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url2")), + "scheme")) + .result(Identifier.of("scheme2")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint1, checkEndpoint2, bind1, use1, bind2, use2) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint1") + .type(ParameterType.STRING) + .build()) + .addParameter(Parameter.builder() + .name("Endpoint2") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Won't coalesce because parseUrl returns record without zero value + Rule transformedRule = transformed.getRules().get(0); + assertEquals(6, transformedRule.getConditions().size()); + } + + @Test + void testCoalesceWithStringFunctions() { + // Test coalescing with string manipulation functions + Condition checkInput = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .result(Identifier.of("prefix")) + .build(); + + Condition use = Condition.builder() + .fn(StringEquals.ofExpressions( + Expression.getReference(Identifier.of("prefix")), + Literal.of("https"))) + .result(Identifier.of("isHttps")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkInput, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce string functions that have zero values + Rule transformedRule = transformed.getRules().get(0); + assertEquals(2, transformedRule.getConditions().size()); // isSet + coalesced + + Condition coalesced = transformedRule.getConditions().get(1); + assertTrue(coalesced.getResult().isPresent()); + assertEquals("isHttps", coalesced.getResult().get().toString()); + } + + @Test + void testDoesNotCoalesceWhenNotImmediatelyFollowing() { + // Bind and use are not immediately following each other + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition intermediate = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint, bind, intermediate, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .addParameter(Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should NOT coalesce because bind and use are not adjacent + Rule transformedRule = transformed.getRules().get(0); + assertEquals(4, transformedRule.getConditions().size()); + } + + @Test + void testCoalesceCaching() { + // parseUrl returns a record type which doesn't have a zero value + // So these won't be coalesced + Condition check1 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Rule rule1 = EndpointRule.builder() + .conditions( + check1, + Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(), + Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Condition check2 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Rule rule2 = EndpointRule.builder() + .conditions( + check2, + Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(), + Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Won't coalesce because parseUrl returns record without zero value + assertEquals(3, transformed.getRules().get(0).getConditions().size()); + assertEquals(3, transformed.getRules().get(1).getConditions().size()); + } + + @Test + void testCoalesceInErrorRule() { + // parseUrl returns a record type which doesn't have a zero value + // So it won't be coalesced + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Condition check = Condition.builder() + .fn(Not.ofExpressions( + StringEquals.ofExpressions( + Expression.getReference(Identifier.of("scheme")), + Literal.of("https")))) + .build(); + + Rule rule = ErrorRule.builder() + .conditions(checkEndpoint, bind, use, check) + .error(Literal.of("Endpoint must use HTTPS")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Won't coalesce because parseUrl returns record without zero value + Rule transformedRule = transformed.getRules().get(0); + assertEquals(4, transformedRule.getConditions().size()); + } + + @Test + void testCoalesceWithBooleanType() { + // Test coalescing with boolean-returning functions + Condition checkInput = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.isValidHostLabel("Input", false)) + .result(Identifier.of("isValid")) + .build(); + + Condition use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("isValid")), + Literal.of(true))) + .result(Identifier.of("validLabel")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkInput, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce boolean functions (they have zero value of false) + Rule transformedRule = transformed.getRules().get(0); + assertEquals(2, transformedRule.getConditions().size()); // isSet + coalesced + } + + @Test + void testDoesNotCoalesceWhenVariableUsedElsewhere() { + // Variable is used in a different branch + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Rule innerRule = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "authority")) + .result(Identifier.of("authority")) + .build()) + .endpoint(TestHelpers.endpoint("https://inner.example.com")); + + Rule treeRule = TreeRule.builder() + .conditions(checkEndpoint, bind, use) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should NOT coalesce because 'url' is used in the inner rule + TreeRule transformedTree = (TreeRule) transformed.getRules().get(0); + assertEquals(3, transformedTree.getConditions().size()); // isSet + bind + use (not coalesced) + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraphTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraphTest.java new file mode 100644 index 00000000000..f118ce41e65 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraphTest.java @@ -0,0 +1,120 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class ConditionDependencyGraphTest { + + @Test + void testBasicVariableDependency() { + // Condition that defines a variable + Condition definer = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + // Condition that uses the variable + Condition user = Condition.builder() + .fn(BooleanEquals.ofExpressions(Expression.of("{hasRegion}"), Expression.of(true))) + .build(); + + List conditions = Arrays.asList(definer, user); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); + + // Definer has no dependencies + assertTrue(graph.getDependencies(definer).isEmpty()); + + // User depends on definer + Set userDeps = graph.getDependencies(user); + assertEquals(1, userDeps.size()); + assertTrue(userDeps.contains(definer)); + } + + @Test + void testIsSetDependencyForNonIsSetCondition() { + // isSet condition for a variable + Condition isSetCondition = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + // Non-isSet condition using the same variable + Condition userCondition = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("{Region}"), Expression.of("us-east-1"))) + .build(); + + List conditions = Arrays.asList(isSetCondition, userCondition); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); + + // Non-isSet condition depends on isSet for undefined variables + Set userDeps = graph.getDependencies(userCondition); + assertEquals(1, userDeps.size()); + assertTrue(userDeps.contains(isSetCondition)); + } + + @Test + void testMultipleDependencies() { + // Define two variables + Condition definer1 = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition definer2 = Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .result(Identifier.of("hasBucket")) + .build(); + + // Use both variables + Condition user = Condition.builder() + .fn(BooleanEquals.ofExpressions( + BooleanEquals.ofExpressions(Expression.of("{hasRegion}"), Expression.of(true)), + BooleanEquals.ofExpressions(Expression.of("{hasBucket}"), Expression.of(true)))) + .build(); + + List conditions = Arrays.asList(definer1, definer2, user); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); + + // User depends on both definers + Set userDeps = graph.getDependencies(user); + assertEquals(2, userDeps.size()); + assertTrue(userDeps.contains(definer1)); + assertTrue(userDeps.contains(definer2)); + } + + @Test + void testUnknownConditionReturnsEmptyDependencies() { + Condition known = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Condition unknown = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); + + List conditions = Collections.singletonList(known); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); + + // Getting dependencies for unknown condition returns empty set + assertTrue(graph.getDependencies(unknown).isEmpty()); + } + + @Test + void testGraphSize() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("A")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("B")).build(); + Condition cond3 = Condition.builder().fn(TestHelpers.isSet("C")).build(); + + List conditions = Arrays.asList(cond1, cond2, cond3); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); + + assertEquals(3, graph.size()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java new file mode 100644 index 00000000000..29b9c87f03e --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java @@ -0,0 +1,181 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.utils.ListUtils; +import software.amazon.smithy.utils.MapUtils; + +class ReferenceRewriterTest { + + @Test + void testSimpleReferenceReplacement() { + // Create a rewriter that replaces "x" with "y" + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("y"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + + // Test rewriting a simple reference + Reference original = Expression.getReference(Identifier.of("x")); + Expression rewritten = rewriter.rewrite(original); + + assertEquals("y", ((Reference) rewritten).getName().toString()); + } + + @Test + void testNoRewriteNeeded() { + // Create a rewriter with no relevant replacements + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("y"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + + // Reference to "z" should not be rewritten + Reference original = Expression.getReference(Identifier.of("z")); + Expression rewritten = rewriter.rewrite(original); + + assertEquals(original, rewritten); + } + + @Test + void testRewriteInStringLiteral() { + // Create a string literal with template variable + Template template = Template.fromString("Value is {x}"); + Literal original = Literal.stringLiteral(template); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("newVar"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertInstanceOf(StringLiteral.class, rewritten); + StringLiteral rewrittenStr = (StringLiteral) rewritten; + assertTrue(rewrittenStr.toString().contains("newVar")); + assertNotEquals(original, rewritten); + } + + @Test + void testRewriteInTupleLiteral() { + // Create a tuple with references + Literal original = Literal.tupleLiteral(ListUtils.of(Literal.of("constant"), Literal.of("{x}"))); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("replaced"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertInstanceOf(TupleLiteral.class, rewritten); + TupleLiteral rewrittenTuple = (TupleLiteral) rewritten; + assertEquals(2, rewrittenTuple.members().size()); + assertTrue(rewrittenTuple.members().get(1).toString().contains("replaced")); + } + + @Test + void testRewriteInRecordLiteral() { + // Create a record with references + Literal original = Literal.recordLiteral(MapUtils.of( + Identifier.of("field1"), + Literal.of("value1"), + Identifier.of("field2"), + Literal.of("{x}"))); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("newX"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertInstanceOf(RecordLiteral.class, rewritten); + RecordLiteral rewrittenRecord = (RecordLiteral) rewritten; + assertEquals(2, rewrittenRecord.members().size()); + assertTrue(rewrittenRecord.members().get(Identifier.of("field2")).toString().contains("newX")); + } + + @Test + void testRewriteInLibraryFunction() { + // Create a function that uses references + Expression original = StringEquals.ofExpressions( + Expression.getReference(Identifier.of("x")), + Literal.of("test")); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("replacedVar"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertTrue(rewritten.toString().contains("replacedVar")); + assertNotEquals(original, rewritten); + } + + @Test + void testMultipleReplacements() { + // Create a function with multiple references + Expression original = BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("a")), + Expression.getReference(Identifier.of("b"))); + + Map replacements = new HashMap<>(); + replacements.put("a", Expression.getReference(Identifier.of("x"))); + replacements.put("b", Expression.getReference(Identifier.of("y"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertTrue(rewritten.toString().contains("x")); + assertTrue(rewritten.toString().contains("y")); + } + + @Test + void testNestedRewriting() { + // Create nested functions with references + Expression inner = IsSet.ofExpressions(Expression.getReference(Identifier.of("x"))); + Expression original = BooleanEquals.ofExpressions(inner, Literal.of(true)); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("newVar"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertTrue(rewritten.toString().contains("newVar")); + assertNotEquals(original, rewritten); + } + + @Test + void testStaticStringNotRewritten() { + // Static strings without templates should not be rewritten + Literal original = Literal.of("static string"); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("y"))); + + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertEquals(original, rewritten); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java new file mode 100644 index 00000000000..c5d7b5c405f --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -0,0 +1,231 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +public class SsaTransformTest { + + @Test + void testNoDisambiguationNeeded() { + // When variables are not shadowed, they should remain unchanged + Parameter bucketParam = Parameter.builder() + .name("Bucket") + .type(ParameterType.STRING) + .build(); + + Condition condition1 = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Bucket"), Expression.of("mybucket"))) + .result("bucketMatches") + .build(); + + EndpointRule rule = EndpointRule.builder() + .conditions(Collections.singletonList(condition1)) + .endpoint(endpoint("https://example.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(bucketParam).build()) + .rules(Collections.singletonList(rule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + assertEquals(original, result); + } + + @Test + void testSimpleShadowing() { + // Test when the same variable name is bound to different expressions + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + // Create rules that will have shadowing after disambiguation + List rules = Arrays.asList( + createRuleWithBinding("Input", "a", "temp", "https://branch1.com"), + createRuleWithBinding("Input", "b", "temp", "https://branch2.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(rules) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + List resultRules = result.getRules(); + assertEquals(2, resultRules.size()); + + EndpointRule resultRule1 = (EndpointRule) resultRules.get(0); + assertEquals("temp_ssa_1", resultRule1.getConditions().get(0).getResult().get().toString()); + + EndpointRule resultRule2 = (EndpointRule) resultRules.get(1); + assertEquals("temp_ssa_2", resultRule2.getConditions().get(0).getResult().get().toString()); + } + + @Test + void testMultipleShadowsOfSameVariable() { + // Test when a variable is shadowed multiple times + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + List rules = Arrays.asList( + createRuleWithBinding("Input", "x", "temp", "https://1.com"), + createRuleWithBinding("Input", "y", "temp", "https://2.com"), + createRuleWithBinding("Input", "z", "temp", "https://3.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(rules) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + List resultRules = result.getRules(); + assertEquals("temp_ssa_1", resultRules.get(0).getConditions().get(0).getResult().get().toString()); + assertEquals("temp_ssa_2", resultRules.get(1).getConditions().get(0).getResult().get().toString()); + assertEquals("temp_ssa_3", resultRules.get(2).getConditions().get(0).getResult().get().toString()); + } + + @Test + void testErrorRuleHandling() { + // Test that error rules are handled correctly + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + Condition cond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Input"), Expression.of("error"))) + .result("hasError") + .build(); + + ErrorRule errorRule = ErrorRule.builder() + .conditions(Collections.singletonList(cond)) + .error(Expression.of("Error occurred")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(Collections.singletonList(errorRule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + assertEquals(1, result.getRules().size()); + assertInstanceOf(ErrorRule.class, result.getRules().get(0)); + } + + @Test + void testTreeRuleHandling() { + // Test tree rules with unique variable names at each level + Parameter param = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + + // Outer condition with one variable + Condition outerCond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Region"), Expression.of("us-*"))) + .result("isUS") + .build(); + + // Inner rules with their own variables + EndpointRule innerRule1 = createRuleWithBinding("Region", "us-east-1", "isEast", "https://east.com"); + EndpointRule innerRule2 = createRuleWithBinding("Region", "us-west-2", "isWest", "https://west.com"); + + TreeRule treeRule = TreeRule.builder() + .conditions(Collections.singletonList(outerCond)) + .treeRule(innerRule1, innerRule2); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(Collections.singletonList(treeRule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Check structure is preserved + assertInstanceOf(TreeRule.class, result.getRules().get(0)); + TreeRule resultTree = (TreeRule) result.getRules().get(0); + assertEquals(2, resultTree.getRules().size()); + } + + @Test + void testParameterShadowingAttempt() { + // Test that attempting to shadow a parameter gets disambiguated + Parameter bucketParam = Parameter.builder() + .name("Bucket") + .type(ParameterType.STRING) + .build(); + + // Create a condition that assigns to "Bucket_shadow" to avoid direct conflict + Condition shadowingCond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Bucket"), Expression.of("test"))) + .result("Bucket_shadow") + .build(); + + EndpointRule rule = EndpointRule.builder() + .conditions(Collections.singletonList(shadowingCond)) + .endpoint(endpoint("https://example.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(bucketParam).build()) + .rules(Collections.singletonList(rule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Should handle without issues + EndpointRule resultRule = (EndpointRule) result.getRules().get(0); + assertEquals("Bucket_shadow", resultRule.getConditions().get(0).getResult().get().toString()); + } + + private static EndpointRule createRuleWithBinding(String param, String value, String resultVar, String url) { + Condition cond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of(param), Expression.of(value))) + .result(resultVar) + .build(); + + return EndpointRule.builder() + .conditions(Collections.singletonList(cond)) + .endpoint(endpoint(url)); + } + + private static Expression expr(String value) { + return Literal.stringLiteral(Template.fromString(value)); + } + + private static Endpoint endpoint(String value) { + return Endpoint.builder().url(expr(value)).build(); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java new file mode 100644 index 00000000000..ddb1505f7dc --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java @@ -0,0 +1,469 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; + +class VariableAnalysisTest { + + @Test + void testSimpleVariableBinding() { + // Rule with one variable binding + Condition condition = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(condition) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.getInputParams().contains("Region")); + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertEquals(0, analysis.getReferenceCount("hasRegion")); + } + + @Test + void testVariableReference() { + // Define and use a variable + Condition define = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(define, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(1, analysis.getReferenceCount("hasRegion")); + assertTrue(analysis.isReferencedOnce("hasRegion")); + + assertTrue(analysis.isSafeToInline("hasRegion")); + } + + @Test + void testMultipleBindings() { + // Same variable assigned in different branches + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("x")) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .result(Identifier.of("x")) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertFalse(analysis.hasSingleBinding("x")); + assertTrue(analysis.hasMultipleBindings("x")); + + // Not safe to inline when multiple bindings exist + assertFalse(analysis.isSafeToInline("x")); + + // Should have different SSA names for different expressions + Map> mappings = analysis.getExpressionMappings(); + assertNotNull(mappings.get("x")); + assertEquals(2, mappings.get("x").size()); + } + + @Test + void testMultipleReferences() { + // Variable referenced multiple times + Condition define = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use1 = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Condition use2 = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(false))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(define, use1, use2) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(2, analysis.getReferenceCount("hasRegion")); + assertFalse(analysis.isReferencedOnce("hasRegion")); + + assertFalse(analysis.isSafeToInline("hasRegion")); + } + + @Test + void testReferencesInEndpoint() { + // Variable used in endpoint URL - just use the Region parameter directly + Condition checkRegion = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .build(); + + Endpoint endpoint = Endpoint.builder() + .url(Literal.stringLiteral(Template.fromString("https://{Region}.example.com"))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkRegion) + .endpoint(endpoint); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(2, analysis.getReferenceCount("Region")); + } + + @Test + void testReferencesInErrorRule() { + // First prove Region is set, then check if it's invalid + Condition checkRegion = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition checkInvalid = Condition.builder() + .fn(StringEquals.ofExpressions( + Expression.getReference(Identifier.of("Region")), + Literal.of("invalid"))) + .result(Identifier.of("isInvalid")) + .build(); + + // Use the Region value directly in the error message + Rule rule = ErrorRule.builder() + .conditions(checkRegion, checkInvalid) + .error(Literal.stringLiteral(Template.fromString("Invalid region: {Region}"))); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(3, analysis.getReferenceCount("Region")); + assertTrue(analysis.getReferenceCount("hasRegion") >= 0); + assertEquals(0, analysis.getReferenceCount("isInvalid")); + } + + @Test + void testNestedTreeRuleAnalysis() { + // Nested rules with variable bindings + Condition outerDefine = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition innerUse = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Rule innerRule = EndpointRule.builder() + .conditions(innerUse) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Rule treeRule = TreeRule.builder() + .conditions(outerDefine) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertEquals(1, analysis.getReferenceCount("hasRegion")); + assertTrue(analysis.isSafeToInline("hasRegion")); + } + + @Test + void testInputParametersIdentified() { + // Multiple input parameters + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("UseDualStack").type(ParameterType.BOOLEAN).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(3, analysis.getInputParams().size()); + assertTrue(analysis.getInputParams().contains("Region")); + assertTrue(analysis.getInputParams().contains("Bucket")); + assertTrue(analysis.getInputParams().contains("UseDualStack")); + } + + @Test + void testNoVariables() { + // Simple ruleset with no variable bindings + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + // No variables bound + assertEquals(0, analysis.getReferenceCount("anyVar")); + assertFalse(analysis.hasSingleBinding("anyVar")); + assertFalse(analysis.isSafeToInline("anyVar")); + + // But Region is an input parameter + assertTrue(analysis.getInputParams().contains("Region")); + } + + @Test + void testSameExpressionDifferentVariableNames() { + // Same expression bound to different variable names in different rules + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("regionExists")) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + // Each variable has a single binding (same expression though) + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertTrue(analysis.hasSingleBinding("regionExists")); + + // Neither is referenced after binding + assertEquals(0, analysis.getReferenceCount("hasRegion")); + assertEquals(0, analysis.getReferenceCount("regionExists")); + } + + @Test + void testDeeplyNestedTreeRules() { + // Multiple levels of tree rule nesting + Condition level3Define = Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .result(Identifier.of("hasBucket")) + .build(); + + Rule level3Rule = EndpointRule.builder() + .conditions(level3Define) + .endpoint(TestHelpers.endpoint("https://level3.com")); + + Condition level2Define = Condition.builder() + .fn(TestHelpers.isSet("Key")) + .result(Identifier.of("hasKey")) + .build(); + + Rule level2Rule = TreeRule.builder() + .conditions(level2Define) + .treeRule(level3Rule); + + Condition level1Define = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition level1Use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Rule level1Rule = TreeRule.builder() + .conditions(level1Define, level1Use) + .treeRule(level2Rule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Key").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(level1Rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertTrue(analysis.hasSingleBinding("hasKey")); + assertTrue(analysis.hasSingleBinding("hasBucket")); + + assertEquals(1, analysis.getReferenceCount("hasRegion")); + assertEquals(0, analysis.getReferenceCount("hasKey")); + assertEquals(0, analysis.getReferenceCount("hasBucket")); + } + + @Test + void testUnreferencedVariable() { + // Variable that's defined but never used + Condition defineUnused = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("unused")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(defineUnused) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.hasSingleBinding("unused")); + assertEquals(0, analysis.getReferenceCount("unused")); + assertFalse(analysis.isSafeToInline("unused")); // Not safe because not referenced + } + + @Test + void testEmptyRuleSet() { + // Empty ruleset with just parameters + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(1, analysis.getInputParams().size()); + assertTrue(analysis.getInputParams().contains("Region")); + assertEquals(0, analysis.getReferenceCount("Region")); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java new file mode 100644 index 00000000000..88c2aab778e --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java @@ -0,0 +1,102 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.traits; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.utils.ListUtils; + +public class BddTraitTest { + @Test + void testBddTraitSerialization() { + // Create a BddTrait with full context + Parameters params = Parameters.builder().build(); + Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); + results.add(endpoint); + + Bdd bdd = createSimpleBdd(); + + EndpointBddTrait original = EndpointBddTrait.builder() + .parameters(params) + .conditions(ListUtils.of(cond)) + .results(results) + .bdd(bdd) + .build(); + + // Serialize to Node + Node node = original.toNode(); + assertTrue(node.isObjectNode()); + assertTrue(node.expectObjectNode().containsMember("parameters")); + assertTrue(node.expectObjectNode().containsMember("conditions")); + assertTrue(node.expectObjectNode().containsMember("results")); + + // Serialized should only have 1 result (the endpoint, not NoMatch) + int serializedResultCount = node.expectObjectNode() + .expectArrayMember("results") + .getElements() + .size(); + assertEquals(1, serializedResultCount); + + // Deserialize from Node + EndpointBddTrait restored = EndpointBddTrait.fromNode(node); + + assertEquals(original.getParameters(), restored.getParameters()); + assertEquals(original.getConditions().size(), restored.getConditions().size()); + assertEquals(original.getResults().size(), restored.getResults().size()); + assertEquals(original.getBdd().getRootRef(), restored.getBdd().getRootRef()); + assertEquals(original.getBdd().getConditionCount(), restored.getBdd().getConditionCount()); + assertEquals(original.getBdd().getResultCount(), restored.getBdd().getResultCount()); + + // Verify NoMatchRule was restored at index 0 + assertInstanceOf(NoMatchRule.class, restored.getResults().get(0)); + } + + private Bdd createSimpleBdd() { + int[] nodes = new int[] { + -1, + 1, + -1, // node 0: terminal + 0, + Bdd.RESULT_OFFSET + 1, + -1 // node 1 + }; + return new Bdd(2, 1, 2, 2, nodes); + } + + @Test + void testEmptyBddTrait() { + Parameters params = Parameters.builder().build(); + int[] nodes = new int[] {-1, 1, -1}; + Bdd bdd = new Bdd(-1, 0, 1, 1, nodes); + + EndpointBddTrait trait = EndpointBddTrait.builder() + .parameters(params) + .conditions(ListUtils.of()) + .results(ListUtils.of(NoMatchRule.INSTANCE)) + .bdd(bdd) + .build(); + + assertEquals(0, trait.getConditions().size()); + assertEquals(1, trait.getResults().size()); + assertEquals(-1, trait.getBdd().getRootRef()); // FALSE terminal + } +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors new file mode 100644 index 00000000000..dc90845867d --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors @@ -0,0 +1 @@ +[ERROR] example#FizzBuzz: Coalesce requires rules engine version >= 1.1, but ruleset declares version 1.0 | RulesEngineVersion diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy new file mode 100644 index 00000000000..11ca6b72297 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy @@ -0,0 +1,45 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet + +@clientContextParams( + foo: {type: "string", documentation: "a client string parameter"} +) +@endpointRuleSet({ + "version": "1.0", + "parameters": { + "foo": { + "type": "String", + "documentation": "docs" + } + }, + "rules": [ + { + "conditions": [ + { + "fn": "coalesce", + "argv": [ + {"ref": "foo"}, + "" + ], + "assign": "hi" + } + ], + "documentation": "base rule", + "endpoint": { + "url": "https://{hi}.amazonaws.com", + "headers": {} + }, + "type": "endpoint" + } + ] +}) +@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + operations: [GetResource] +} + +operation GetResource {} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.errors new file mode 100644 index 00000000000..95d19a87d72 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.errors @@ -0,0 +1,3 @@ +[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#clientContextParams | UnstableTrait.smithy.rules#clientContextParams +[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#endpointRuleSet | UnstableTrait.smithy.rules#endpointRuleSet +[ERROR] example#FizzBuzz: URI should start with `http://` or `https://` but the URI started with foo://example.com/ | RuleSetUri diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.smithy new file mode 100644 index 00000000000..7df32ec370e --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.smithy @@ -0,0 +1,47 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet + +@clientContextParams( + bar: {type: "string", documentation: "a client string parameter"} +) +@endpointRuleSet({ + version: "1.0", + parameters: { + bar: { + type: "string", + documentation: "docs" + } + }, + rules: [ + { + "documentation": "lorem ipsum dolor", + "conditions": [ + { + "fn": "isSet", + "argv": [ + { + "ref": "bar" + } + ] + } + ], + "type": "endpoint", + "endpoint": { + "url": "foo://example.com/" + } + }, + { + "conditions": [], + "documentation": "error fallthrough", + "error": "endpoint error", + "type": "error" + } + ] +}) +service FizzBuzz { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.errors new file mode 100644 index 00000000000..e69de29bb2d diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.smithy new file mode 100644 index 00000000000..54d4b6802d8 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.smithy @@ -0,0 +1,93 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +@clientContextParams( + bar: {type: "string", documentation: "a client string parameter"} + baz: {type: "string", documentation: "another client string parameter"} +) +@endpointRuleSet({ + version: "1.1", + parameters: { + bar: { + type: "string", + documentation: "docs" + } + baz: { + type: "string", + documentation: "docs" + } + }, + rules: [ + { + "documentation": "Template baz into URI when bar is set" + "conditions": [ + { + "fn": "coalesce" + "argv": [{"ref": "bar"}, {"ref": "baz"}, "oops"] + "assign": "hi" + } + ] + "endpoint": { + "url": "https://example.com/{hi}" + } + "type": "endpoint" + } + ] +}) +@endpointTests({ + "version": "1.0", + "testCases": [ + { + "params": { + "bar": "bar", + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com/bar" + } + } + } + { + "params": { + "baz": "baz" + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com/baz" + } + } + } + { + "params": {} + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com/oops" + } + } + } + ] +}) +@suppress(["RuleSetParameter.TestCase.Unused"]) +@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + version: "2022-01-01", + operations: [GetThing] +} + +operation GetThing { + input := {} +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-with-boolean.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-with-boolean.errors new file mode 100644 index 00000000000..e69de29bb2d diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-with-boolean.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-with-boolean.smithy new file mode 100644 index 00000000000..978cf1a735d --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-with-boolean.smithy @@ -0,0 +1,118 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +@clientContextParams( + bar: {type: "boolean", documentation: "a boolean value"} + baz: {type: "boolean", documentation: "another boolean value"} +) +@endpointRuleSet({ + version: "1.1", + parameters: { + bar: { + type: "boolean", + documentation: "docs" + } + baz: { + type: "boolean", + documentation: "docs" + } + }, + rules: [ + { + "documentation": "Template qux into URI when bar is set" + "conditions": [ + { + "fn": "coalesce" + "argv": [ + {"ref": "bar"} + {"ref": "baz"} + ] + } + ] + "endpoint": { + "url": "https://example.com" + } + "type": "endpoint" + } + { + "documentation": "Did not match" + "conditions": [] + "error": "Did not match" + "type": "error" + } + ] +}) +@endpointTests({ + "version": "1.0", + "testCases": [ + { + "params": { + "bar": false // not null, so don't even look at baz, and it's falsey condition, so go to error + "baz": false + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "error": "Did not match" + } + } + { + "params": { + "bar": true // true != null, pick this, then the condition is truthy, so endpoint + "baz": false + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com" + } + } + } + { + "params": { + // bar: null -- skip null values in coalesce + "baz": true // truthy, so resolve an endpoint + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com" + } + } + } + { + "params": { + "bar": true // not null, truthy, so get an endpoint + // baz is null, but bar is not-null first + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com" + } + } + } + ] +}) +@suppress(["RuleSetParameter.TestCase.Unused"]) +@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + version: "2022-01-01", + operations: [GetThing] +} + +operation GetThing { + input := {} +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy index 42d93b922d4..f76b3d279bc 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy @@ -158,7 +158,7 @@ use smithy.rules#endpointTests "documentation": "unicode characters always return `None`", "params": { "TestCaseId": "1", - "Input": "abcdef\uD83D\uDC31" + "Input": "\uD83D\uDC31abcdef" }, "expect": { "error": "No tests matched" @@ -168,7 +168,7 @@ use smithy.rules#endpointTests "documentation": "non-ascii cause substring to always return `None`", "params": { "TestCaseId": "1", - "Input": "abcdef\u0080" + "Input": "ab\u0080cdef" }, "expect": { "error": "No tests matched" diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 index d9f6d0bffb2..c508ad3f5cc 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 @@ -1,7 +1,7 @@ // when parsing endpoint ruleset // while parsing rule // at invalid-rules/empty-rule.json5:15 -// Missing expected member `conditions`. +// Missing expected member `type`. { "version": "1.2", "parameters": { diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json index 92b6d8292e5..3e554bce10b 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json @@ -4,13 +4,13 @@ "Region": { "builtIn": "AWS::Region", "required": true, - "type": "String" + "type": "string" } }, "rules": [ { - "conditions": [], "documentation": "base rule", + "conditions": [], "endpoint": { "url": "https://{Region}.amazonaws.com", "properties": { diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors new file mode 100644 index 00000000000..aa03a54143f --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#endpointBdd`: Input byte array has wrong 4-byte ending unit | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy new file mode 100644 index 00000000000..7ce9f68fabd --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy @@ -0,0 +1,36 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#endpointBdd + +@endpointBdd({ + version: "1.1" + parameters: { + Region: { + type: "string" + required: true + documentation: "The AWS region" + } + } + conditions: [ + { + fn: "isSet" + argv: [{ref: "Region"}] + } + ] + results: [ + { + type: "endpoint" + endpoint: { + url: "https://service.{Region}.amazonaws.com" + } + } + ] + nodes: "ABCD=" // invalid base64 + nodeCount: 3 + root: 1 +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors new file mode 100644 index 00000000000..7ef5d590fd7 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#endpointBdd`: Expected 36 bytes for 3 nodes, but got 2 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy new file mode 100644 index 00000000000..b3e27e3b824 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy @@ -0,0 +1,52 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#endpointBdd + +@endpointBdd({ + version: "1.1" + parameters: { + Region: { + type: "string" + required: true + documentation: "The AWS region" + } + UseFips: { + type: "boolean" + required: true + default: false + documentation: "Use FIPS endpoints" + } + } + conditions: [ + { + fn: "isSet" + argv: [{ref: "Region"}] + } + { + fn: "booleanEquals" + argv: [{ref: "UseFips"}, true] + } + ] + results: [ + { + type: "endpoint" + endpoint: { + url: "https://service.{Region}.amazonaws.com" + } + } + { + type: "endpoint" + endpoint: { + url: "https://service-fips.{Region}.amazonaws.com" + } + } + ] + nodes: "AQB" // bad data, valid base64 + nodeCount: 3 + root: 1 +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors new file mode 100644 index 00000000000..8703a938001 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#InvalidRootRefService: Error creating trait `smithy.rules#endpointBdd`: Root reference cannot be complemented: -5 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy new file mode 100644 index 00000000000..2976c0be659 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy @@ -0,0 +1,22 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#endpointBdd + +@endpointBdd({ + version: "1.1" + parameters: {} + conditions: [] + results: [] + nodes: "" // Base64 encoded empty node array + root: -5 // Invalid negative root reference (only -1 is allowed for FALSE) + nodeCount: 0 +}) +service InvalidRootRefService { + version: "2022-01-01" + operations: [GetThing] +} + +@readonly +operation GetThing {} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors new file mode 100644 index 00000000000..19ee0420061 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors @@ -0,0 +1,2 @@ +[WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#endpointBdd | UnstableTrait.smithy.rules#endpointBdd +[WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#clientContextParams | UnstableTrait.smithy.rules#clientContextParams diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy new file mode 100644 index 00000000000..86d695be02c --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy @@ -0,0 +1,64 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#clientContextParams +use smithy.rules#endpointBdd + +@clientContextParams( + Region: {type: "string", documentation: "docs"} + UseFips: {type: "boolean", documentation: "docs"} +) +@endpointBdd({ + version: "1.1" + "parameters": { + "Region": { + "required": true, + "documentation": "The AWS region", + "type": "string" + }, + "UseFips": { + "required": true, + "default": false, + "documentation": "Use FIPS endpoints", + "type": "boolean" + } + }, + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseFips" + }, + true + ] + } + ], + "results": [ + { + "conditions": [], + "endpoint": { + "url": "https://service-fips.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + }, + { + "conditions": [], + "endpoint": { + "url": "https://service.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + } + ], + "root": 2, + "nodeCount": 2, + "nodes": "/////wAAAAH/////AAAAAAX14QEF9eEC" +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.errors new file mode 100644 index 00000000000..e18a498effe --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#endpointBdd`: Rules engine version for endpointBdd trait must be >= 1.1 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.smithy new file mode 100644 index 00000000000..f37ef99b450 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.smithy @@ -0,0 +1,64 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#clientContextParams +use smithy.rules#endpointBdd + +@clientContextParams( + Region: {type: "string", documentation: "docs"} + UseFips: {type: "boolean", documentation: "docs"} +) +@endpointBdd({ + version: "1.0" + "parameters": { + "Region": { + "required": true, + "documentation": "The AWS region", + "type": "string" + }, + "UseFips": { + "required": true, + "default": false, + "documentation": "Use FIPS endpoints", + "type": "boolean" + } + }, + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseFips" + }, + true + ] + } + ], + "results": [ + { + "conditions": [], + "endpoint": { + "url": "https://service-fips.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + }, + { + "conditions": [], + "endpoint": { + "url": "https://service.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + } + ], + "root": 2, + "nodeCount": 2, + "nodes": "/////wAAAAH/////AAAAAAX14QEF9eEC" +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors index 053df0d3352..788659c192c 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors @@ -1 +1 @@ -[ERROR] smithy.example#InvalidService: Trait `smithy.rules#endpointTests` cannot be applied to `smithy.example#InvalidService`. This trait may only be applied to shapes that match the following selector: service[trait|smithy.rules#endpointRuleSet] | TraitTarget +[ERROR] smithy.example#InvalidService: Trait `smithy.rules#endpointTests` cannot be applied to `smithy.example#InvalidService`. This trait may only be applied to shapes that match the following selector: service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#endpointBdd]) | TraitTarget