diff --git a/hcl2/api.py b/hcl2/api.py index 399ba929..1cec02a2 100644 --- a/hcl2/api.py +++ b/hcl2/api.py @@ -3,7 +3,7 @@ from lark.tree import Tree from hcl2.parser import parser, reconstruction_parser -from hcl2.transformer import DictTransformer +from hcl2.dict_transformer import DictTransformer from hcl2.reconstructor import HCLReconstructor, HCLReverseTransformer diff --git a/hcl2/const.py b/hcl2/const.py index 1d46f35a..1bd4a4ce 100644 --- a/hcl2/const.py +++ b/hcl2/const.py @@ -2,3 +2,4 @@ START_LINE_KEY = "__start_line__" END_LINE_KEY = "__end_line__" +IS_BLOCK = "__is_block__" diff --git a/hcl2/transformer.py b/hcl2/dict_transformer.py similarity index 99% rename from hcl2/transformer.py rename to hcl2/dict_transformer.py index 382092d6..64c58bcb 100644 --- a/hcl2/transformer.py +++ b/hcl2/dict_transformer.py @@ -277,6 +277,10 @@ def heredoc_template_trim(self, args: List) -> str: def new_line_or_comment(self, args: List) -> _DiscardType: return Discard + # def EQ(self, args: List): + # print("EQ", args) + # return args + def for_tuple_expr(self, args: List) -> str: args = self.strip_new_line_tokens(args) for_expr = " ".join([self.to_tf_inline(arg) for arg in args[1:-1]]) diff --git a/hcl2/parser.py b/hcl2/parser.py index 79d50122..3e524736 100644 --- a/hcl2/parser.py +++ b/hcl2/parser.py @@ -12,7 +12,7 @@ def parser() -> Lark: """Build standard parser for transforming HCL2 text into python structures""" return Lark.open( - "hcl2.lark", + "rule_transformer/hcl2.lark", parser="lalr", cache=str(PARSER_FILE), # Disable/Delete file to effect changes to the grammar rel_to=__file__, @@ -29,7 +29,7 @@ def reconstruction_parser() -> Lark: if necessary. 
""" return Lark.open( - "hcl2.lark", + "rule_transformer/hcl2.lark", parser="lalr", # Caching must be disabled to allow for reconstruction until lark-parser/lark#1472 is fixed: # diff --git a/hcl2/reconstructor.py b/hcl2/reconstructor.py index 7f957d7b..555edcf6 100644 --- a/hcl2/reconstructor.py +++ b/hcl2/reconstructor.py @@ -167,12 +167,17 @@ def _should_add_space(self, rule, current_terminal, is_block_label: bool = False if self._is_equals_sign(current_terminal): return True + if is_block_label: + pass + # print(rule, self._last_rule, current_terminal, self._last_terminal) + if is_block_label and isinstance(rule, Token) and rule.value == "string": if ( current_terminal == self._last_terminal == Terminal("DBLQUOTE") or current_terminal == Terminal("DBLQUOTE") - and self._last_terminal == Terminal("NAME") + and self._last_terminal == Terminal("IDENTIFIER") ): + # print("true") return True # if we're in a ternary or binary operator, add space around the operator diff --git a/hcl2/rule_transformer/__init__.py b/hcl2/rule_transformer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/hcl2/rule_transformer/deserializer.py b/hcl2/rule_transformer/deserializer.py new file mode 100644 index 00000000..7b834968 --- /dev/null +++ b/hcl2/rule_transformer/deserializer.py @@ -0,0 +1,279 @@ +import json +from functools import lru_cache +from typing import Any, TextIO, List + +from regex import regex + +from hcl2 import parses +from hcl2.const import IS_BLOCK +from hcl2.rule_transformer.rules.abstract import LarkElement, LarkRule +from hcl2.rule_transformer.rules.base import ( + BlockRule, + AttributeRule, + BodyRule, + StartRule, +) +from hcl2.rule_transformer.rules.containers import ( + TupleRule, + ObjectRule, + ObjectElemRule, + ObjectElemKeyExpressionRule, + ObjectElemKeyDotAccessor, + ObjectElemKeyRule, +) +from hcl2.rule_transformer.rules.expressions import ExprTermRule +from hcl2.rule_transformer.rules.literal_rules import ( + IdentifierRule, + 
IntLitRule, + FloatLitRule, +) +from hcl2.rule_transformer.rules.strings import ( + StringRule, + InterpolationRule, + StringPartRule, +) +from hcl2.rule_transformer.rules.tokens import ( + NAME, + EQ, + DBLQUOTE, + STRING_CHARS, + ESCAPED_INTERPOLATION, + INTERP_START, + RBRACE, + IntLiteral, + FloatLiteral, + RSQB, + LSQB, + COMMA, + DOT, + LBRACE, +) +from hcl2.rule_transformer.transformer import RuleTransformer +from hcl2.rule_transformer.utils import DeserializationOptions + + +class Deserializer: + def __init__(self, options=DeserializationOptions()): + self.options = options + + @property + @lru_cache + def _transformer(self) -> RuleTransformer: + return RuleTransformer() + + def load_python(self, value: Any) -> LarkElement: + return StartRule([self._deserialize(value)]) + + def loads(self, value: str) -> LarkElement: + return self.load_python(json.loads(value)) + + def load(self, file: TextIO) -> LarkElement: + return self.loads(file.read()) + + def _deserialize(self, value: Any) -> LarkElement: + if isinstance(value, dict): + if self._contains_block_marker(value): + elements = self._deserialize_block_elements(value) + return BodyRule(elements) + + return self._deserialize_object(value) + + if isinstance(value, list): + return self._deserialize_list(value) + + return self._deserialize_text(value) + + def _deserialize_block_elements(self, value: dict) -> List[LarkRule]: + children = [] + + for key, value in value.items(): + if self._is_block(value): + # this value is a list of blocks, iterate over each block and deserialize them + for block in value: + children.append(self._deserialize_block(key, block)) + else: + + # otherwise it's just an attribute + if key != IS_BLOCK: + children.append(self._deserialize_attribute(key, value)) + + return children + + def _deserialize_text(self, value) -> LarkRule: + try: + int_val = int(value) + return IntLitRule([IntLiteral(int_val)]) + except ValueError: + pass + + try: + float_val = float(value) + return 
FloatLitRule([FloatLiteral(float_val)]) + except ValueError: + pass + + if isinstance(value, str): + if value.startswith('"') and value.endswith('"'): + return self._deserialize_string(value) + + if self._is_expression(value): + return self._deserialize_expression(value) + + return self._deserialize_identifier(value) + + elif isinstance(value, bool): + return self._deserialize_identifier(str(value).lower()) + + return self._deserialize_identifier(str(value)) + + def _deserialize_identifier(self, value: str) -> IdentifierRule: + return IdentifierRule([NAME(value)]) + + def _deserialize_string(self, value: str) -> StringRule: + result = [] + + pattern = regex.compile(r"(\${1,2}\{(?:[^{}]|(?R))*\})") + parts = [part for part in pattern.split(value) if part != ""] + # e.g. 'aaa$${bbb}ccc${"ddd-${eee}"}' -> ['aaa', '$${bbb}', 'ccc', '${"ddd-${eee}"}'] + # 'aa-${"bb-${"cc-${"dd-${5 + 5}"}"}"}' -> ['aa-', '${"bb-${"cc-${"dd-${5 + 5}"}"}"}'] + + for part in parts: + if part == '"': + continue + + if part.startswith('"'): + part = part[1:] + if part.endswith('"'): + part = part[:-1] + + e = self._deserialize_string_part(part) + result.append(e) + + return StringRule([DBLQUOTE(), *result, DBLQUOTE()]) + + def _deserialize_string_part(self, value: str) -> StringPartRule: + if value.startswith("$${") and value.endswith("}"): + return StringPartRule([ESCAPED_INTERPOLATION(value)]) + + if value.startswith("${") and value.endswith("}"): + return StringPartRule( + [ + InterpolationRule( + [INTERP_START(), self._deserialize_expression(value), RBRACE()] + ) + ] + ) + + return StringPartRule([STRING_CHARS(value)]) + + def _deserialize_expression(self, value: str) -> ExprTermRule: + """Deserialize an expression string into an ExprTermRule.""" + # instead of processing expression manually and trying to recognize what kind of expression it is, + # turn it into HCL2 code and parse it with lark: + + # unwrap from ${ and } + value = value[2:-1] + # create HCL2 snippet + value = f"temp = 
{value}" + # parse the above + parsed_tree = parses(value) + # transform parsed tree into LarkElement tree + rules_tree = self._transformer.transform(parsed_tree) + # extract expression from the tree + return rules_tree.body.children[0].expression + + def _deserialize_block(self, first_label: str, value: dict) -> BlockRule: + """Deserialize a block by extracting labels and body""" + labels = [first_label] + body = value + + # Keep peeling off single-key layers until we hit the body (dict with IS_BLOCK) + while isinstance(body, dict) and not body.get(IS_BLOCK): + non_block_keys = [k for k in body.keys() if k != IS_BLOCK] + if len(non_block_keys) == 1: + # This is another label level + label = non_block_keys[0] + labels.append(label) + body = body[label] + else: + # Multiple keys = this is the body + break + + return BlockRule( + [*[self._deserialize(label) for label in labels], self._deserialize(body)] + ) + + def _deserialize_attribute(self, name: str, value: Any) -> AttributeRule: + children = [ + self._deserialize_identifier(name), + EQ(), + ExprTermRule([self._deserialize(value)]), + ] + return AttributeRule(children) + + def _deserialize_list(self, value: List) -> TupleRule: + children = [] + for element in value: + deserialized = self._deserialize(element) + if not isinstance(deserialized, ExprTermRule): + # whatever an element of the list is, it has to be nested inside ExprTermRule + deserialized = ExprTermRule([deserialized]) + children.append(deserialized) + children.append(COMMA()) + + return TupleRule([LSQB(), *children, RSQB()]) + + def _deserialize_object(self, value: dict) -> ObjectRule: + children = [] + for key, value in value.items(): + children.append(self._deserialize_object_elem(key, value)) + return ObjectRule([LBRACE(), *children, RBRACE()]) + + def _deserialize_object_elem(self, key: str, value: Any) -> ObjectElemRule: + if self._is_expression(key): + key = ObjectElemKeyExpressionRule([self._deserialize_expression(key)]) + elif "." 
in key: + parts = key.split(".") + children = [] + for part in parts: + children.append(self._deserialize_identifier(part)) + children.append(DOT()) + key = ObjectElemKeyDotAccessor(children[:-1]) # without the last comma + else: + key = self._deserialize_text(key) + + return ObjectElemRule( + [ + ObjectElemKeyRule([key]), + EQ(), + ExprTermRule([self._deserialize_text(value)]), + ] + ) + + def _is_expression(self, value: str) -> bool: + return value.startswith("${") and value.endswith("}") + + def _is_block(self, value: Any) -> bool: + """Simple check: if it's a list containing dicts with IS_BLOCK markers""" + if not isinstance(value, list) or len(value) == 0: + return False + + # Check if any item in the list has IS_BLOCK marker (directly or nested) + for item in value: + if isinstance(item, dict) and self._contains_block_marker(item): + return True + + return False + + def _contains_block_marker(self, obj: dict) -> bool: + """Recursively check if a dict contains IS_BLOCK marker anywhere""" + if obj.get(IS_BLOCK): + return True + for value in obj.values(): + if isinstance(value, dict) and self._contains_block_marker(value): + return True + if isinstance(value, list): + for element in value: + if self._contains_block_marker(element): + return True + return False diff --git a/hcl2/rule_transformer/editor.py b/hcl2/rule_transformer/editor.py new file mode 100644 index 00000000..9efce08f --- /dev/null +++ b/hcl2/rule_transformer/editor.py @@ -0,0 +1,77 @@ +import dataclasses +from copy import copy, deepcopy +from typing import List, Optional, Set, Tuple + +from hcl2.rule_transformer.rules.abstract import LarkRule +from hcl2.rule_transformer.rules.base import BlockRule, StartRule + + +@dataclasses.dataclass +class TreePathElement: + + name: str + index: int = 0 + + +@dataclasses.dataclass +class TreePath: + + elements: List[TreePathElement] = dataclasses.field(default_factory=list) + + @classmethod + def build(cls, elements: List[Tuple[str, Optional[int]] | str]): + 
results = [] + for element in elements: + if isinstance(element, tuple): + if len(element) == 1: + result = TreePathElement(element[0], 0) + else: + result = TreePathElement(*element) + else: + result = TreePathElement(element, 0) + + results.append(result) + + return cls(results) + + def __iter__(self): + return self.elements.__iter__() + + def __len__(self): + return self.elements.__len__() + + +class Editor: + def __init__(self, rules_tree: LarkRule): + self.rules_tree = rules_tree + + @classmethod + def _find_one(cls, rules_tree: LarkRule, path_element: TreePathElement) -> LarkRule: + return cls._find_all(rules_tree, path_element.name)[path_element.index] + + @classmethod + def _find_all(cls, rules_tree: LarkRule, rule_name: str) -> List[LarkRule]: + children = [] + print("rule", rules_tree) + print("rule children", rules_tree.children) + for child in rules_tree.children: + if isinstance(child, LarkRule) and child.lark_name() == rule_name: + children.append(child) + + return children + + def find_by_path(self, path: TreePath, rule_name: str) -> List[LarkRule]: + path = deepcopy(path.elements) + + current_rule = self.rules_tree + while len(path) > 0: + current_path, *path = path + print(current_path, path) + current_rule = self._find_one(current_rule, current_path) + + return self._find_all(current_rule, rule_name) + + # def visit(self, path: TreePath) -> "Editor": + # + # while len(path) > 1: + # current = diff --git a/hcl2/rule_transformer/hcl2.lark b/hcl2/rule_transformer/hcl2.lark new file mode 100644 index 00000000..3f8d913e --- /dev/null +++ b/hcl2/rule_transformer/hcl2.lark @@ -0,0 +1,163 @@ +// ============================================================================ +// Terminals +// ============================================================================ + +// Whitespace and Comments +NL_OR_COMMENT: /\n[ \t]*/ | /#.*\n/ | /\/\/.*\n/ | /\/\*(.|\n)*?(\*\/)/ + +// Keywords +IF : "if" +IN : "in" +FOR : "for" +FOR_EACH : "for_each" + + +// Literals 
+NAME : /[a-zA-Z_][a-zA-Z0-9_-]*/ +ESCAPED_INTERPOLATION.2: /\$\$\{[^}]*\}/ +STRING_CHARS.1: /(?:(?!\$\$\{)(?!\$\{)[^"\\]|\\.|(?:\$(?!\$?\{)))+/ +DECIMAL : "0".."9" +NEGATIVE_DECIMAL : "-" DECIMAL +EXP_MARK : ("e" | "E") ("+" | "-")? DECIMAL+ +INT_LITERAL: NEGATIVE_DECIMAL? DECIMAL+ +FLOAT_LITERAL: (NEGATIVE_DECIMAL? DECIMAL+ | NEGATIVE_DECIMAL+) "." DECIMAL+ (EXP_MARK)? + | (NEGATIVE_DECIMAL? DECIMAL+ | NEGATIVE_DECIMAL+) (EXP_MARK) + +// Operators +BINARY_OP : DOUBLE_EQ | NEQ | LT | GT | LEQ | GEQ | MINUS | ASTERISK | SLASH | PERCENT | DOUBLE_AMP | DOUBLE_PIPE | PLUS +DOUBLE_EQ : "==" +NEQ : "!=" +LT : "<" +GT : ">" +LEQ : "<=" +GEQ : ">=" +MINUS : "-" +ASTERISK : "*" +SLASH : "/" +PERCENT : "%" +DOUBLE_AMP : "&&" +DOUBLE_PIPE : "||" +PLUS : "+" +NOT : "!" +QMARK : "?" + +// Punctuation +LPAR : "(" +RPAR : ")" +LBRACE : "{" +RBRACE : "}" +LSQB : "[" +RSQB : "]" +COMMA : "," +DOT : "." +EQ : /[ \t]*=(?!=|>)/ +COLON : ":" +DBLQUOTE : "\"" + +// Interpolation +INTERP_START : "${" + +// Splat Operators +ATTR_SPLAT : ".*" +FULL_SPLAT_START : "[*]" + +// Special Operators +FOR_OBJECT_ARROW : "=>" +ELLIPSIS : "..." +COLONS: "::" + +// Heredocs +HEREDOC_TEMPLATE : /<<(?P[a-zA-Z][a-zA-Z0-9._-]+)\n?(?:.|\n)*?\n\s*(?P=heredoc)\n/ +HEREDOC_TEMPLATE_TRIM : /<<-(?P[a-zA-Z][a-zA-Z0-9._-]+)\n?(?:.|\n)*?\n\s*(?P=heredoc_trim)\n/ + +// Ignore whitespace (but not newlines, as they're significant in HCL) +%ignore /[ \t]+/ + +// ============================================================================ +// Rules +// ============================================================================ + +// Top-level structure +start : body + +// Body and basic constructs +body : (new_line_or_comment? (attribute | block))* new_line_or_comment? +attribute : identifier EQ expression +block : identifier (identifier | string)* new_line_or_comment? 
LBRACE body RBRACE + +// Whitespace and comments +new_line_or_comment: ( NL_OR_COMMENT )+ + +// Basic literals and identifiers +identifier : NAME +keyword: IN | FOR | IF | FOR_EACH +int_lit: INT_LITERAL +float_lit: FLOAT_LITERAL +string: DBLQUOTE string_part* DBLQUOTE +string_part: STRING_CHARS + | ESCAPED_INTERPOLATION + | interpolation + +// Expressions +?expression : expr_term | operation | conditional +interpolation: INTERP_START expression RBRACE +conditional : expression QMARK new_line_or_comment? expression new_line_or_comment? COLON new_line_or_comment? expression + +// Operations +?operation : unary_op | binary_op +!unary_op : (MINUS | NOT) expr_term +binary_op : expression binary_term new_line_or_comment? +binary_term : binary_operator new_line_or_comment? expression +!binary_operator : BINARY_OP + +// Expression terms +expr_term : LPAR new_line_or_comment? expression new_line_or_comment? RPAR + | float_lit + | int_lit + | string + | tuple + | object + | identifier + | function_call + | heredoc_template + | heredoc_template_trim + | index_expr_term + | get_attr_expr_term + | attr_splat_expr_term + | full_splat_expr_term + | for_tuple_expr + | for_object_expr + +// Collections +tuple : LSQB new_line_or_comment? (expression new_line_or_comment? COMMA new_line_or_comment?)* (expression new_line_or_comment? COMMA? new_line_or_comment?)? RSQB +object : LBRACE new_line_or_comment? ((object_elem | (object_elem new_line_or_comment? COMMA)) new_line_or_comment?)* RBRACE +object_elem : object_elem_key ( EQ | COLON ) expression +object_elem_key : float_lit | int_lit | identifier | string | object_elem_key_dot_accessor | object_elem_key_expression +object_elem_key_expression : LPAR expression RPAR +object_elem_key_dot_accessor : identifier (DOT identifier)+ + +// Heredocs +heredoc_template : HEREDOC_TEMPLATE +heredoc_template_trim : HEREDOC_TEMPLATE_TRIM + +// Functions +function_call : identifier (COLONS identifier COLONS identifier)? LPAR new_line_or_comment? 
arguments? new_line_or_comment? RPAR +arguments : (expression (new_line_or_comment? COMMA new_line_or_comment? expression)* (COMMA | ELLIPSIS)? new_line_or_comment?) + +// Indexing and attribute access +index_expr_term : expr_term index +get_attr_expr_term : expr_term get_attr +attr_splat_expr_term : expr_term attr_splat +full_splat_expr_term : expr_term full_splat +?index : braces_index | short_index +braces_index : LSQB new_line_or_comment? expression new_line_or_comment? RSQB +short_index : DOT INT_LITERAL +get_attr : DOT identifier +attr_splat : ATTR_SPLAT (get_attr | index)* +full_splat : FULL_SPLAT_START (get_attr | index)* + +// For expressions +!for_tuple_expr : LSQB new_line_or_comment? for_intro new_line_or_comment? expression new_line_or_comment? for_cond? new_line_or_comment? RSQB +!for_object_expr : LBRACE new_line_or_comment? for_intro new_line_or_comment? expression FOR_OBJECT_ARROW new_line_or_comment? expression new_line_or_comment? ELLIPSIS? new_line_or_comment? for_cond? new_line_or_comment? RBRACE +!for_intro : FOR new_line_or_comment? identifier (COMMA identifier new_line_or_comment?)? new_line_or_comment? IN new_line_or_comment? expression new_line_or_comment? COLON new_line_or_comment? +!for_cond : IF new_line_or_comment? 
expression diff --git a/hcl2/rule_transformer/json.py b/hcl2/rule_transformer/json.py new file mode 100644 index 00000000..647b6683 --- /dev/null +++ b/hcl2/rule_transformer/json.py @@ -0,0 +1,12 @@ +from json import JSONEncoder +from typing import Any + +from hcl2.rule_transformer.rules.abstract import LarkRule + + +class LarkEncoder(JSONEncoder): + def default(self, obj: Any): + if isinstance(obj, LarkRule): + return obj.serialize() + else: + return super().default(obj) diff --git a/hcl2/rule_transformer/processor.py b/hcl2/rule_transformer/processor.py new file mode 100644 index 00000000..b854aff5 --- /dev/null +++ b/hcl2/rule_transformer/processor.py @@ -0,0 +1,258 @@ +from copy import copy, deepcopy +from typing import ( + List, + Optional, + Union, + Callable, + Any, + Tuple, + Generic, + TypeVar, + cast, + Generator, +) + +from hcl2.rule_transformer.rules.abstract import LarkRule, LarkElement +from hcl2.rule_transformer.rules.base import BlockRule, AttributeRule +from hcl2.rule_transformer.rules.whitespace import NewLineOrCommentRule + +T = TypeVar("T", bound=LarkRule) + + +class RulesProcessor(Generic[T]): + """""" + + @classmethod + def _traverse( + cls, + node: T, + predicate: Callable[[T], bool], + current_depth: int = 0, + max_depth: Optional[int] = None, + ) -> List["RulesProcessor"]: + + results = [] + + if predicate(node): + results.append(cls(node)) + + if max_depth is not None and current_depth >= max_depth: + return results + + for child in node.children: + if child is None or not isinstance(child, LarkRule): + continue + + child_results = cls._traverse( + child, + predicate, + current_depth + 1, + max_depth, + ) + results.extend(child_results) + + return results + + def __init__(self, node: LarkRule): + self.node = node + + @property + def siblings(self): + if self.node.parent is None: + return None + return self.node.parent.children + + @property + def next_siblings(self): + if self.node.parent is None: + return None + return 
self.node.parent.children[self.node.index + 1 :] + + @property + def previous_siblings(self): + if self.node.parent is None: + return None + return self.node.parent.children[: self.node.index - 1] + + def walk(self) -> Generator[Tuple["RulesProcessor", List["RulesProcessor"]]]: + child_processors = [self.__class__(child) for child in self.node.children] + yield self, child_processors + for processor in child_processors: + if isinstance(processor.node, LarkRule): + for result in processor.walk(): + yield result + + def find_block( + self, + labels: List[str], + exact_match: bool = True, + max_depth: Optional[int] = None, + ) -> "RulesProcessor[BlockRule]": + return self.find_blocks(labels, exact_match, max_depth)[0] + + def find_blocks( + self, + labels: List[str], + exact_match: bool = True, + max_depth: Optional[int] = None, + ) -> List["RulesProcessor[BlockRule]"]: + """ + Find blocks by their labels. + + Args: + labels: List of label strings to match + exact_match: If True, all labels must match exactly. If False, labels can be a subset. + max_depth: Maximum depth to search + + Returns: + ... + """ + + def block_predicate(node: LarkRule) -> bool: + if not isinstance(node, BlockRule): + return False + + node_labels = [label.serialize() for label in node.labels] + + if exact_match: + return node_labels == labels + else: + # Check if labels is a prefix of node_labels + if len(labels) > len(node_labels): + return False + return node_labels[: len(labels)] == labels + + return cast( + List[RulesProcessor[BlockRule]], + self._traverse(self.node, block_predicate, max_depth=max_depth), + ) + + def attribute( + self, name: str, max_depth: Optional[int] = None + ) -> "RulesProcessor[AttributeRule]": + return self.find_attributes(name, max_depth)[0] + + def find_attributes( + self, name: str, max_depth: Optional[int] = None + ) -> List["RulesProcessor[AttributeRule]"]: + """ + Find attributes by their identifier name. 
+ + Args: + name: Attribute name to search for + max_depth: Maximum depth to search + + Returns: + List of RulesProcessor objects for matching attributes + """ + + def attribute_predicate(node: LarkRule) -> bool: + if not isinstance(node, AttributeRule): + return False + return node.identifier.serialize() == name + + return self._traverse(self.node, attribute_predicate, max_depth=max_depth) + + def rule(self, rule_name: str, max_depth: Optional[int] = None): + return self.find_rules(rule_name, max_depth)[0] + + def find_rules( + self, rule_name: str, max_depth: Optional[int] = None + ) -> List["RulesProcessor"]: + """ + Find all rules of a specific type. + + Args: + rule_name: Name of the rule type to find + max_depth: Maximum depth to search + + Returns: + List of RulesProcessor objects for matching rules + """ + + def rule_predicate(node: LarkRule) -> bool: + return node.lark_name() == rule_name + + return self._traverse(self.node, rule_predicate, max_depth=max_depth) + + def find_by_predicate( + self, predicate: Callable[[LarkRule], bool], max_depth: Optional[int] = None + ) -> List["RulesProcessor"]: + """ + Find all rules matching a custom predicate. 
+ + Args: + predicate: Function that returns True for nodes to collect + max_depth: Maximum depth to search + + Returns: + List of RulesProcessor objects for matching rules + """ + return self._traverse(self.node, predicate, max_depth) + + # Convenience methods + def get_all_blocks(self, max_depth: Optional[int] = None) -> List: + """Get all blocks in the tree.""" + return self.find_rules("block", max_depth) + + def get_all_attributes( + self, max_depth: Optional[int] = None + ) -> List["RulesProcessor"]: + """Get all attributes in the tree.""" + return self.find_rules("attribute", max_depth) + + def previous(self, skip_new_line: bool = True) -> Optional["RulesProcessor"]: + """Get the previous sibling node.""" + if self.node.parent is None: + return None + + for sibling in reversed(self.previous_siblings): + if sibling is not None and isinstance(sibling, LarkRule): + if skip_new_line and isinstance(sibling, NewLineOrCommentRule): + continue + return self.__class__(sibling) + + def next(self, skip_new_line: bool = True) -> Optional["RulesProcessor"]: + """Get the next sibling node.""" + if self.node.parent is None: + return None + + for sibling in self.next_siblings: + if sibling is not None and isinstance(sibling, LarkRule): + if skip_new_line and isinstance(sibling, NewLineOrCommentRule): + continue + return self.__class__(sibling) + + def append_child( + self, new_node: LarkRule, indentation: bool = True + ) -> "RulesProcessor": + children = self.node.children + if indentation: + if isinstance(children[-1], NewLineOrCommentRule): + children.pop() + children.append(NewLineOrCommentRule.from_string("\n ")) + + new_node = deepcopy(new_node) + new_node.set_parent(self.node) + new_node.set_index(len(children)) + children.append(new_node) + return self.__class__(new_node) + + def replace(self, new_node: LarkRule) -> "RulesProcessor": + new_node = deepcopy(new_node) + + self.node.parent.children.pop(self.node.index) + self.node.parent.children.insert(self.node.index, new_node) + 
new_node.set_parent(self.node.parent) + new_node.set_index(self.node.index) + return self.__class__(new_node) + + # def insert_before(self, new_node: LarkRule) -> bool: + # """Insert a new node before this one.""" + # if self.parent is None or self.parent_index < 0: + # return False + # + # try: + # self.parent.children.insert(self.parent_index, new_node) + # except (IndexError, AttributeError): + # return False diff --git a/hcl2/rule_transformer/rules/__init__.py b/hcl2/rule_transformer/rules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/hcl2/rule_transformer/rules/abstract.py b/hcl2/rule_transformer/rules/abstract.py new file mode 100644 index 00000000..33dcc9ca --- /dev/null +++ b/hcl2/rule_transformer/rules/abstract.py @@ -0,0 +1,107 @@ +from abc import ABC, abstractmethod +from typing import Any, Union, List, Optional, Tuple, Callable + +from lark import Token, Tree +from lark.exceptions import VisitError +from lark.tree import Meta + +from hcl2.rule_transformer.utils import SerializationOptions, SerializationContext + + +class LarkElement(ABC): + @staticmethod + @abstractmethod + def lark_name() -> str: + raise NotImplementedError() + + def __init__(self, index: int = -1, parent: "LarkElement" = None): + self._index = index + self._parent = parent + + def set_index(self, i: int): + self._index = i + + def set_parent(self, node: "LarkElement"): + self._parent = node + + @abstractmethod + def to_lark(self) -> Any: + raise NotImplementedError() + + @abstractmethod + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + raise NotImplementedError() + + +class LarkToken(LarkElement, ABC): + def __init__(self, value: Union[str, int, float]): + self._value = value + super().__init__() + + @property + @abstractmethod + def serialize_conversion(self) -> Callable: + raise NotImplementedError() + + @property + def value(self): + return self._value + + def serialize( + self, 
options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return self.serialize_conversion(self.value) + + def to_lark(self) -> Token: + return Token(self.lark_name(), self.value) + + def __str__(self) -> str: + return str(self._value) + + def __repr__(self) -> str: + return f"" + + +class LarkRule(LarkElement, ABC): + @abstractmethod + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + raise NotImplementedError() + + @property + def children(self) -> List[LarkElement]: + return self._children + + @property + def parent(self): + return self._parent + + @property + def index(self): + return self._index + + def to_lark(self) -> Tree: + result_children = [] + for child in self._children: + if child is None: + continue + + result_children.append(child.to_lark()) + + return Tree(self.lark_name(), result_children, meta=self._meta) + + def __init__(self, children: List[LarkElement], meta: Optional[Meta] = None): + super().__init__() + self._children = children + self._meta = meta + + for index, child in enumerate(children): + if child is not None: + child.set_index(index) + child.set_parent(self) + + def __repr__(self): + return f"" diff --git a/hcl2/rule_transformer/rules/base.py b/hcl2/rule_transformer/rules/base.py new file mode 100644 index 00000000..5c8468d4 --- /dev/null +++ b/hcl2/rule_transformer/rules/base.py @@ -0,0 +1,158 @@ +from collections import defaultdict +from typing import Tuple, Any, List, Union, Optional + +from lark.tree import Meta + +from hcl2.const import IS_BLOCK +from hcl2.rule_transformer.rules.abstract import LarkRule, LarkToken +from hcl2.rule_transformer.rules.expressions import ExpressionRule +from hcl2.rule_transformer.rules.literal_rules import IdentifierRule +from hcl2.rule_transformer.rules.strings import StringRule +from hcl2.rule_transformer.rules.tokens import NAME, EQ + +from hcl2.rule_transformer.rules.whitespace import NewLineOrCommentRule +from 
hcl2.rule_transformer.utils import SerializationOptions, SerializationContext + + +class AttributeRule(LarkRule): + _children: Tuple[ + NAME, + EQ, + ExpressionRule, + ] + + @staticmethod + def lark_name() -> str: + return "attribute" + + @property + def identifier(self) -> NAME: + return self._children[0] + + @property + def expression(self) -> ExpressionRule: + return self._children[2] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return {self.identifier.serialize(options): self.expression.serialize(options)} + + +class BodyRule(LarkRule): + + _children: List[ + Union[ + NewLineOrCommentRule, + AttributeRule, + "BlockRule", + ] + ] + + @staticmethod + def lark_name() -> str: + return "body" + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + blocks: List[BlockRule] = [] + attributes: List[AttributeRule] = [] + comments = [] + inline_comments = [] + + for child in self._children: + + if isinstance(child, BlockRule): + blocks.append(child) + + if isinstance(child, AttributeRule): + attributes.append(child) + # collect in-line comments from attribute assignments, expressions etc + inline_comments.extend(child.expression.inline_comments()) + + if isinstance(child, NewLineOrCommentRule): + child_comments = child.to_list() + if child_comments: + comments.extend(child_comments) + + result = {} + + for attribute in attributes: + result.update(attribute.serialize(options)) + + result_blocks = defaultdict(list) + for block in blocks: + name = block.labels[0].serialize(options) + if name in result.keys(): + raise RuntimeError(f"Attribute {name} is already defined.") + result_blocks[name].append(block.serialize(options)) + + result.update(**result_blocks) + + if options.with_comments: + if comments: + result["__comments__"] = comments + if inline_comments: + result["__inline_comments__"] = inline_comments + + return result + + +class StartRule(LarkRule): + + 
_children: Tuple[BodyRule] + + @property + def body(self) -> BodyRule: + return self._children[0] + + @staticmethod + def lark_name() -> str: + return "start" + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return self.body.serialize(options) + + +class BlockRule(LarkRule): + + _children: Tuple[ + IdentifierRule, + Optional[Union[IdentifierRule, StringRule]], + BodyRule, + ] + + def __init__(self, children, meta: Optional[Meta] = None): + super().__init__(children, meta) + + *self._labels, self._body = [ + child for child in children if not isinstance(child, LarkToken) + ] + + @staticmethod + def lark_name() -> str: + return "block" + + @property + def labels(self) -> List[NAME]: + return list(filter(lambda label: label is not None, self._labels)) + + @property + def body(self) -> BodyRule: + return self._body + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + result = self._body.serialize(options) + if options.explicit_blocks: + result.update({IS_BLOCK: True}) + + labels = self._labels + for label in reversed(labels[1:]): + result = {label.serialize(options): result} + + return result diff --git a/hcl2/rule_transformer/rules/containers.py b/hcl2/rule_transformer/rules/containers.py new file mode 100644 index 00000000..11ac0f5e --- /dev/null +++ b/hcl2/rule_transformer/rules/containers.py @@ -0,0 +1,204 @@ +import json +from typing import Tuple, List, Optional, Union, Any + +from hcl2.rule_transformer.rules.abstract import LarkRule +from hcl2.rule_transformer.rules.expressions import ExpressionRule +from hcl2.rule_transformer.rules.literal_rules import ( + FloatLitRule, + IntLitRule, + IdentifierRule, +) +from hcl2.rule_transformer.rules.strings import StringRule +from hcl2.rule_transformer.rules.tokens import ( + COLON, + EQ, + LBRACE, + COMMA, + RBRACE, LSQB, RSQB, LPAR, RPAR, DOT, +) +from hcl2.rule_transformer.rules.whitespace import ( + 
NewLineOrCommentRule, + InlineCommentMixIn, +) +from hcl2.rule_transformer.utils import SerializationOptions, SerializationContext, to_dollar_string + + +class TupleRule(InlineCommentMixIn): + + _children: Tuple[ + LSQB, + Optional[NewLineOrCommentRule], + Tuple[ + ExpressionRule, + Optional[NewLineOrCommentRule], + COMMA, + Optional[NewLineOrCommentRule], + ... + ], + ExpressionRule, + Optional[NewLineOrCommentRule], + Optional[COMMA], + Optional[NewLineOrCommentRule], + RSQB, + ] + + @staticmethod + def lark_name() -> str: + return "tuple" + + @property + def elements(self) -> List[ExpressionRule]: + return [ + child for child in self.children[1:-1] if isinstance(child, ExpressionRule) + ] + + def serialize(self, options = SerializationOptions(), context = SerializationContext()) -> Any: + if not options.wrap_tuples: + return [element.serialize(options, context) for element in self.elements] + + with context.modify(inside_dollar_string=True): + result = f"[{", ".join( + str(element.serialize(options, context)) for element in self.elements + )}]" + + if not context.inside_dollar_string: + result = to_dollar_string(result) + + return result + + +class ObjectElemKeyRule(LarkRule): + + key_T = Union[FloatLitRule, IntLitRule, IdentifierRule, StringRule] + + _children: Tuple[key_T] + + @staticmethod + def lark_name() -> str: + return "object_elem_key" + + @property + def value(self) -> key_T: + return self._children[0] + + def serialize(self, options = SerializationOptions(), context = SerializationContext()) -> Any: + return self.value.serialize(options, context) + + +class ObjectElemKeyExpressionRule(LarkRule): + + _children: Tuple[ + LPAR, + ExpressionRule, + RPAR, + ] + + + @staticmethod + def lark_name() -> str: + return "object_elem_key_expression" + + @property + def expression(self) -> ExpressionRule: + return self._children[1] + + def serialize(self, options=SerializationOptions(), context=SerializationContext()) -> Any: + with 
context.modify(inside_dollar_string=True): + result = f"({self.expression.serialize(options, context)})" + if not context.inside_dollar_string: + result = to_dollar_string(result) + return result + + +class ObjectElemKeyDotAccessor(LarkRule): + + _children: Tuple[ + IdentifierRule, + Tuple[ + IdentifierRule, + DOT, + ... + ] + ] + + @staticmethod + def lark_name() -> str: + return "object_elem_key_dot_accessor" + + @property + def identifiers(self) -> List[IdentifierRule]: + return [child for child in self._children if isinstance(child, IdentifierRule)] + + def serialize(self, options=SerializationOptions(), context=SerializationContext()) -> Any: + return ".".join(identifier.serialize(options, context) for identifier in self.identifiers) + + +class ObjectElemRule(LarkRule): + + _children: Tuple[ + ObjectElemKeyRule, + Union[EQ, COLON], + ExpressionRule, + ] + + @staticmethod + def lark_name() -> str: + return "object_elem" + + @property + def key(self) -> ObjectElemKeyRule: + return self._children[0] + + @property + def expression(self): + return self._children[2] + + def serialize(self, options = SerializationOptions(), context = SerializationContext()) -> Any: + return { + self.key.serialize(options, context): self.expression.serialize(options, context) + } + + +class ObjectRule(InlineCommentMixIn): + + _children: Tuple[ + LBRACE, + Optional[NewLineOrCommentRule], + Tuple[ + ObjectElemRule, + Optional[NewLineOrCommentRule], + Optional[COMMA], + Optional[NewLineOrCommentRule], + ... 
+ ], + RBRACE, + ] + + @staticmethod + def lark_name() -> str: + return "object" + + @property + def elements(self) -> List[ObjectElemRule]: + return [ + child for child in self.children[1:-1] if isinstance(child, ObjectElemRule) + ] + + def serialize(self, options = SerializationOptions(), context = SerializationContext()) -> Any: + if not options.wrap_objects: + result = {} + for element in self.elements: + result.update(element.serialize(options, context)) + + return result + + with context.modify(inside_dollar_string=True): + result = f"{{{", ".join( + f"{element.key.serialize(options, context)} = {element.expression.serialize(options,context)}" + for element in self.elements + )}}}" + + if not context.inside_dollar_string: + result = to_dollar_string(result) + + return result diff --git a/hcl2/rule_transformer/rules/expressions.py b/hcl2/rule_transformer/rules/expressions.py new file mode 100644 index 00000000..d89f3b3c --- /dev/null +++ b/hcl2/rule_transformer/rules/expressions.py @@ -0,0 +1,220 @@ +from abc import ABC +from copy import deepcopy +from typing import Any, Tuple, Optional + +from lark.tree import Meta + +from hcl2.rule_transformer.rules.abstract import ( + LarkToken, +) +from hcl2.rule_transformer.rules.literal_rules import BinaryOperatorRule +from hcl2.rule_transformer.rules.tokens import LPAR, RPAR, QMARK, COLON +from hcl2.rule_transformer.rules.whitespace import ( + NewLineOrCommentRule, + InlineCommentMixIn, +) +from hcl2.rule_transformer.utils import ( + wrap_into_parentheses, + to_dollar_string, + unwrap_dollar_string, + SerializationOptions, + SerializationContext, +) + + +class ExpressionRule(InlineCommentMixIn, ABC): + @staticmethod + def lark_name() -> str: + return "expression" + + def __init__(self, children, meta: Optional[Meta] = None): + super().__init__(children, meta) + + +class ExprTermRule(ExpressionRule): + + type_ = Tuple[ + Optional[LPAR], + Optional[NewLineOrCommentRule], + ExpressionRule, + Optional[NewLineOrCommentRule], 
+ Optional[RPAR], + ] + + _children: type_ + + @staticmethod + def lark_name() -> str: + return "expr_term" + + def __init__(self, children, meta: Optional[Meta] = None): + self._parentheses = False + if ( + isinstance(children[0], LarkToken) + and children[0].lark_name() == "LPAR" + and isinstance(children[-1], LarkToken) + and children[-1].lark_name() == "RPAR" + ): + self._parentheses = True + else: + children = [None, *children, None] + self._possibly_insert_null_comments(children, [1, 3]) + super().__init__(children, meta) + + @property + def parentheses(self) -> bool: + return self._parentheses + + @property + def expression(self) -> ExpressionRule: + return self._children[2] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + result = self.expression.serialize(options, context) + + if self.parentheses: + result = wrap_into_parentheses(result) + if not context.inside_dollar_string: + result = to_dollar_string(result) + + return result + + +class ConditionalRule(ExpressionRule): + + _children: Tuple[ + ExpressionRule, + QMARK, + Optional[NewLineOrCommentRule], + ExpressionRule, + Optional[NewLineOrCommentRule], + COLON, + Optional[NewLineOrCommentRule], + ExpressionRule, + ] + + @staticmethod + def lark_name() -> str: + return "conditional" + + def __init__(self, children, meta: Optional[Meta] = None): + self._possibly_insert_null_comments(children, [2, 4, 6]) + super().__init__(children, meta) + + @property + def condition(self) -> ExpressionRule: + return self._children[0] + + @property + def if_true(self) -> ExpressionRule: + return self._children[3] + + @property + def if_false(self) -> ExpressionRule: + return self._children[7] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + with context.modify(inside_dollar_string=False): + result = ( + f"{self.condition.serialize(options, context)} " + f"? 
{self.if_true.serialize(options, context)} " + f": {self.if_false.serialize(options, context)}" + ) + + if not context.inside_dollar_string: + result = to_dollar_string(result) + + return result + + +class BinaryTermRule(ExpressionRule): + + _children: Tuple[ + BinaryOperatorRule, + Optional[NewLineOrCommentRule], + ExprTermRule, + ] + + @staticmethod + def lark_name() -> str: + return "binary_term" + + def __init__(self, children, meta: Optional[Meta] = None): + self._possibly_insert_null_comments(children, [1]) + super().__init__(children, meta) + + @property + def binary_operator(self) -> BinaryOperatorRule: + return self._children[0] + + @property + def expr_term(self) -> ExprTermRule: + return self._children[2] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return f"{self.binary_operator.serialize(options, context)} {self.expr_term.serialize(options, context)}" + + +class BinaryOpRule(ExpressionRule): + _children: Tuple[ + ExprTermRule, + BinaryTermRule, + Optional[NewLineOrCommentRule], + ] + + @staticmethod + def lark_name() -> str: + return "binary_op" + + @property + def expr_term(self) -> ExprTermRule: + return self._children[0] + + @property + def binary_term(self) -> BinaryTermRule: + return self._children[1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + + with context.modify(inside_dollar_string=True): + lhs = self.expr_term.serialize(options, context) + operator = self.binary_term.binary_operator.serialize(options, context) + rhs = self.binary_term.expr_term.serialize(options, context) + + result = f"{lhs} {operator} {rhs}" + + if not context.inside_dollar_string: + result = to_dollar_string(result) + return result + + +class UnaryOpRule(ExpressionRule): + + _children: Tuple[LarkToken, ExprTermRule] + + @staticmethod + def lark_name() -> str: + return "unary_op" + + @property + def operator(self) -> str: + return str(self._children[0]) + 
+ @property + def expr_term(self): + return self._children[1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return to_dollar_string( + f"{self.operator}{self.expr_term.serialize(options, context)}" + ) diff --git a/hcl2/rule_transformer/rules/functions.py b/hcl2/rule_transformer/rules/functions.py new file mode 100644 index 00000000..54958514 --- /dev/null +++ b/hcl2/rule_transformer/rules/functions.py @@ -0,0 +1,104 @@ +from functools import lru_cache +from typing import Any, Optional, Tuple, Union, List + +from hcl2.rule_transformer.rules.expressions import ExpressionRule +from hcl2.rule_transformer.rules.literal_rules import IdentifierRule +from hcl2.rule_transformer.rules.tokens import COMMA, ELLIPSIS, StringToken, LPAR, RPAR +from hcl2.rule_transformer.rules.whitespace import InlineCommentMixIn, NewLineOrCommentRule +from hcl2.rule_transformer.utils import SerializationOptions, SerializationContext, to_dollar_string + + +class ArgumentsRule(InlineCommentMixIn): + + _children: Tuple[ + ExpressionRule, + Tuple[ + Optional[NewLineOrCommentRule], + COMMA, + Optional[NewLineOrCommentRule], + ExpressionRule, + ... + ], + Optional[Union[COMMA, ELLIPSIS]], + Optional[NewLineOrCommentRule], + ] + + @staticmethod + def lark_name() -> str: + return "arguments" + + @property + @lru_cache(maxsize=None) + def has_ellipsis(self) -> bool: + for child in self._children[-2:]: + if isinstance(child, StringToken) and child.lark_name() == "ELLIPSIS": + return True + return False + + @property + def arguments(self) -> List[ExpressionRule]: + return [child for child in self._children if isinstance(child, ExpressionRule)] + + def serialize(self, options = SerializationOptions(), context = SerializationContext()) -> Any: + result = ", ".join([argument.serialize(options, context) for argument in self.arguments]) + if self.has_ellipsis: + result += " ..." 
+ return result + + +class FunctionCallRule(InlineCommentMixIn): + + _children: Tuple[ + IdentifierRule, + Optional[IdentifierRule], + Optional[IdentifierRule], + LPAR, + Optional[NewLineOrCommentRule], + Optional[ArgumentsRule], + Optional[NewLineOrCommentRule], + RPAR, + ] + + @staticmethod + def lark_name() -> str: + return "function_call" + + @property + @lru_cache(maxsize=None) + def identifiers(self) -> List[IdentifierRule]: + return [child for child in self._children if isinstance(child, IdentifierRule)] + + @property + @lru_cache(maxsize=None) + def arguments(self) -> Optional[ArgumentsRule]: + for child in self._children[2:6]: + if isinstance(child, ArgumentsRule): + return child + + + def serialize(self, options = SerializationOptions(), context = SerializationContext()) -> Any: + result = ( + f"{"::".join(identifier.serialize(options, context) for identifier in self.identifiers)}" + f"({self.arguments.serialize(options, context) if self.arguments else ""})" + ) + if not context.inside_dollar_string: + result = to_dollar_string(result) + + return result + + +# class ProviderFunctionCallRule(FunctionCallRule): +# _children: Tuple[ +# IdentifierRule, +# IdentifierRule, +# IdentifierRule, +# LPAR, +# Optional[NewLineOrCommentRule], +# Optional[ArgumentsRule], +# Optional[NewLineOrCommentRule], +# RPAR, +# ] +# +# @staticmethod +# def lark_name() -> str: +# return "provider_function_call" diff --git a/hcl2/rule_transformer/rules/indexing.py b/hcl2/rule_transformer/rules/indexing.py new file mode 100644 index 00000000..7a9b53a5 --- /dev/null +++ b/hcl2/rule_transformer/rules/indexing.py @@ -0,0 +1,240 @@ +from typing import List, Optional, Tuple, Any, Union + +from lark.tree import Meta + +from hcl2.rule_transformer.rules.abstract import LarkRule +from hcl2.rule_transformer.rules.expressions import ExprTermRule, ExpressionRule +from hcl2.rule_transformer.rules.literal_rules import IdentifierRule +from hcl2.rule_transformer.rules.tokens import ( + DOT, + 
IntLiteral, + LSQB, + RSQB, + ATTR_SPLAT, +) +from hcl2.rule_transformer.rules.whitespace import ( + InlineCommentMixIn, + NewLineOrCommentRule, +) +from hcl2.rule_transformer.utils import ( + SerializationOptions, + to_dollar_string, + SerializationContext, +) + + +class ShortIndexRule(LarkRule): + + _children: Tuple[ + DOT, + IntLiteral, + ] + + @staticmethod + def lark_name() -> str: + return "short_index" + + @property + def index(self): + return self.children[1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return f".{self.index.serialize(options)}" + + +class SqbIndexRule(InlineCommentMixIn): + _children: Tuple[ + LSQB, + Optional[NewLineOrCommentRule], + ExprTermRule, + Optional[NewLineOrCommentRule], + RSQB, + ] + + @staticmethod + def lark_name() -> str: + return "braces_index" + + @property + def index_expression(self): + return self.children[2] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return f"[{self.index_expression.serialize(options)}]" + + def __init__(self, children, meta: Optional[Meta] = None): + self._possibly_insert_null_comments(children, [1, 3]) + super().__init__(children, meta) + + +class IndexExprTermRule(ExpressionRule): + + _children: Tuple[ExprTermRule, SqbIndexRule] + + @staticmethod + def lark_name() -> str: + return "index_expr_term" + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + with context.modify(inside_dollar_string=True): + result = f"{self.children[0].serialize(options)}{self.children[1].serialize(options)}" + if not context.inside_dollar_string: + result = to_dollar_string(result) + return result + + +class GetAttrRule(LarkRule): + + _children: Tuple[ + DOT, + IdentifierRule, + ] + + @staticmethod + def lark_name() -> str: + return "get_attr" + + @property + def identifier(self) -> IdentifierRule: + return self._children[1] + + def serialize( + self, 
options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return f".{self.identifier.serialize(options, context)}" + + +class GetAttrExprTermRule(ExpressionRule): + + _children: Tuple[ + ExprTermRule, + GetAttrRule, + ] + + @staticmethod + def lark_name() -> str: + return "get_attr_expr_term" + + @property + def expr_term(self) -> ExprTermRule: + return self._children[0] + + @property + def get_attr(self) -> GetAttrRule: + return self._children[1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + with context.modify(inside_dollar_string=True): + result = f"{self.expr_term.serialize(options, context)}{self.get_attr.serialize(options, context)}" + if not context.inside_dollar_string: + result = to_dollar_string(result) + return result + + +class AttrSplatRule(LarkRule): + _children: Tuple[ + ATTR_SPLAT, + Tuple[Union[GetAttrRule, Union[SqbIndexRule, ShortIndexRule]], ...], + ] + + @staticmethod + def lark_name() -> str: + return "attr_splat" + + @property + def get_attrs( + self, + ) -> List[Union[GetAttrRule, Union[SqbIndexRule, ShortIndexRule]]]: + return self._children[1:] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return ".*" + "".join( + get_attr.serialize(options, context) for get_attr in self.get_attrs + ) + + +class AttrSplatExprTermRule(ExpressionRule): + + _children: Tuple[ExprTermRule, AttrSplatRule] + + @staticmethod + def lark_name() -> str: + return "attr_splat_expr_term" + + @property + def expr_term(self) -> ExprTermRule: + return self._children[0] + + @property + def attr_splat(self) -> AttrSplatRule: + return self._children[1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + with context.modify(inside_dollar_string=True): + result = f"{self.expr_term.serialize(options, context)}{self.attr_splat.serialize(options, context)}" + + if not context.inside_dollar_string: 
+ result = to_dollar_string(result) + return result + + +class FullSplatRule(LarkRule): + _children: Tuple[ + ATTR_SPLAT, + Tuple[Union[GetAttrRule, Union[SqbIndexRule, ShortIndexRule]], ...], + ] + + @staticmethod + def lark_name() -> str: + return "full_splat" + + @property + def get_attrs( + self, + ) -> List[Union[GetAttrRule, Union[SqbIndexRule, ShortIndexRule]]]: + return self._children[1:] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return "[*]" + "".join( + get_attr.serialize(options, context) for get_attr in self.get_attrs + ) + + +class FullSplatExprTermRule(ExpressionRule): + _children: Tuple[ExprTermRule, FullSplatRule] + + @staticmethod + def lark_name() -> str: + return "full_splat_expr_term" + + @property + def expr_term(self) -> ExprTermRule: + return self._children[0] + + @property + def attr_splat(self) -> FullSplatRule: + return self._children[1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + with context.modify(inside_dollar_string=True): + result = f"{self.expr_term.serialize(options, context)}{self.attr_splat.serialize(options, context)}" + + if not context.inside_dollar_string: + result = to_dollar_string(result) + return result diff --git a/hcl2/rule_transformer/rules/literal_rules.py b/hcl2/rule_transformer/rules/literal_rules.py new file mode 100644 index 00000000..baf8546f --- /dev/null +++ b/hcl2/rule_transformer/rules/literal_rules.py @@ -0,0 +1,49 @@ +from abc import ABC +from typing import Any, Tuple + +from hcl2.rule_transformer.rules.abstract import LarkRule, LarkToken +from hcl2.rule_transformer.utils import SerializationOptions, SerializationContext + + +class TokenRule(LarkRule, ABC): + + _children: Tuple[LarkToken] + + @property + def token(self) -> LarkToken: + return self._children[0] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return 
self.token.serialize() + + +class KeywordRule(TokenRule): + @staticmethod + def lark_name() -> str: + return "keyword" + + +class IdentifierRule(TokenRule): + @staticmethod + def lark_name() -> str: + return "identifier" + + +class IntLitRule(TokenRule): + @staticmethod + def lark_name() -> str: + return "int_lit" + + +class FloatLitRule(TokenRule): + @staticmethod + def lark_name() -> str: + return "float_lit" + + +class BinaryOperatorRule(TokenRule): + @staticmethod + def lark_name() -> str: + return "binary_operator" diff --git a/hcl2/rule_transformer/rules/strings.py b/hcl2/rule_transformer/rules/strings.py new file mode 100644 index 00000000..769ad5b9 --- /dev/null +++ b/hcl2/rule_transformer/rules/strings.py @@ -0,0 +1,73 @@ +from typing import Tuple, List, Any, Union + +from hcl2.rule_transformer.rules.abstract import LarkRule +from hcl2.rule_transformer.rules.expressions import ExpressionRule +from hcl2.rule_transformer.rules.tokens import ( + INTERP_START, + RBRACE, + DBLQUOTE, + STRING_CHARS, + ESCAPED_INTERPOLATION, +) +from hcl2.rule_transformer.utils import ( + SerializationOptions, + SerializationContext, + to_dollar_string, +) + + +class InterpolationRule(LarkRule): + + _children: Tuple[ + INTERP_START, + ExpressionRule, + RBRACE, + ] + + @staticmethod + def lark_name() -> str: + return "interpolation" + + @property + def expression(self): + return self.children[1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return to_dollar_string(self.expression.serialize(options)) + + +class StringPartRule(LarkRule): + _children: Tuple[Union[STRING_CHARS, ESCAPED_INTERPOLATION, InterpolationRule]] + + @staticmethod + def lark_name() -> str: + return "string_part" + + @property + def content(self) -> Union[STRING_CHARS, ESCAPED_INTERPOLATION, InterpolationRule]: + return self._children[0] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return 
self.content.serialize(options, context) + + +class StringRule(LarkRule): + + _children: Tuple[DBLQUOTE, List[StringPartRule], DBLQUOTE] + + @staticmethod + def lark_name() -> str: + return "string" + + @property + def string_parts(self): + return self.children[1:-1] + + def serialize( + self, options=SerializationOptions(), context=SerializationContext() + ) -> Any: + return '"' + "".join(part.serialize() for part in self.string_parts) + '"' diff --git a/hcl2/rule_transformer/rules/tokens.py b/hcl2/rule_transformer/rules/tokens.py new file mode 100644 index 00000000..59e524f3 --- /dev/null +++ b/hcl2/rule_transformer/rules/tokens.py @@ -0,0 +1,111 @@ +from functools import lru_cache +from typing import Callable, Any, Type, Optional, Tuple + +from hcl2.rule_transformer.rules.abstract import LarkToken + + +class StringToken(LarkToken): + """ + Single run-time base class; every `StringToken["..."]` call returns a + cached subclass whose static `lark_name()` yields the given string. + """ + + @classmethod + @lru_cache(maxsize=None) + def __build_subclass(cls, name: str) -> Type["StringToken"]: + """Create a subclass with a constant `lark_name`.""" + return type( # type: ignore + f"{name}_TOKEN", + (StringToken,), + { + "__slots__": (), + "lark_name": staticmethod(lambda _n=name: _n), + }, + ) + + def __class_getitem__(cls, name: str) -> Type["StringToken"]: + if not isinstance(name, str): + raise TypeError("StringToken[...] 
expects a single str argument") + return cls.__build_subclass(name) + + def __init__(self, value: Optional[Any] = None): + super().__init__(value) + + @property + def serialize_conversion(self) -> Callable[[Any], str]: + return str + + +class StaticStringToken(LarkToken): + @classmethod + @lru_cache(maxsize=None) + def __build_subclass( + cls, name: str, default_value: str = None + ) -> Type["StringToken"]: + """Create a subclass with a constant `lark_name`.""" + + return type( # type: ignore + f"{name}_TOKEN", + (cls,), + { + "__slots__": (), + "lark_name": staticmethod(lambda _n=name: _n), + "_default_value": default_value, + }, + ) + + def __class_getitem__(cls, value: Tuple[str, str]) -> Type["StringToken"]: + name, default_value = value + return cls.__build_subclass(name, default_value) + + def __init__(self): + super().__init__(getattr(self, "_default_value")) + + @property + def serialize_conversion(self) -> Callable[[Any], str]: + return str + + +# explicitly define various kinds of string-based tokens for type hinting +# variable value +NAME = StringToken["NAME"] +STRING_CHARS = StringToken["STRING_CHARS"] +ESCAPED_INTERPOLATION = StringToken["ESCAPED_INTERPOLATION"] +BINARY_OP = StringToken["BINARY_OP"] +# static value +EQ = StaticStringToken[("EQ", "=")] +COLON = StaticStringToken[("COLON", ":")] +LPAR = StaticStringToken[("LPAR", "(")] +RPAR = StaticStringToken[("RPAR", ")")] +LBRACE = StaticStringToken[("LBRACE", "{")] +RBRACE = StaticStringToken[("RBRACE", "}")] +DOT = StaticStringToken[("DOT", ".")] +COMMA = StaticStringToken[("COMMA", ",")] +ELLIPSIS = StaticStringToken[("ELLIPSIS", "...")] +QMARK = StaticStringToken[("QMARK", "?")] +LSQB = StaticStringToken[("LSQB", "[")] +RSQB = StaticStringToken[("RSQB", "]")] +INTERP_START = StaticStringToken[("INTERP_START", "${")] +DBLQUOTE = StaticStringToken[("DBLQUOTE", '"')] +ATTR_SPLAT = StaticStringToken[("ATTR_SPLAT", ".*")] +FULL_SPLAT = StaticStringToken[("FULL_SPLAT", "[*]")] + + +class 
IntLiteral(LarkToken): + @staticmethod + def lark_name() -> str: + return "INT_LITERAL" + + @property + def serialize_conversion(self) -> Callable: + return int + + +class FloatLiteral(LarkToken): + @staticmethod + def lark_name() -> str: + return "FLOAT_LITERAL" + + @property + def serialize_conversion(self) -> Callable: + return float diff --git a/hcl2/rule_transformer/rules/tree.py b/hcl2/rule_transformer/rules/tree.py new file mode 100644 index 00000000..e39d2077 --- /dev/null +++ b/hcl2/rule_transformer/rules/tree.py @@ -0,0 +1,106 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Any, Union + + +class LarkNode(ABC): + """Base class for all nodes in the tree""" + + def __init__(self, index: int = -1, parent: Optional["Node"] = None): + self._index = index + self._parent = parent + + @property + def parent(self) -> Optional["Node"]: + return self._parent + + @property + def index(self) -> int: + return self._index + + def set_parent(self, parent: "Node"): + self._parent = parent + + def set_index(self, index: int): + self._index = index + + @abstractmethod + def serialize(self, options=None) -> Any: + pass + + @abstractmethod + def to_lark(self) -> Any: + """Convert back to Lark representation""" + pass + + def is_leaf(self) -> bool: + """Check if this is a leaf node (atomic token)""" + return isinstance(self, LeafNode) + + def is_sequence(self) -> bool: + """Check if this is a token sequence node""" + return isinstance(self, SequenceNode) + + def is_internal(self) -> bool: + """Check if this is an internal node (grammar rule)""" + return isinstance(self, InternalNode) + + def is_atomic(self) -> bool: + """Check if this represents an atomic value (leaf or sequence)""" + return self.is_leaf() or self.is_sequence() + + +class LarkLeaf(Node, ABC): + """""" + + def __init__(self, value: Any, index: int = -1, parent: Optional[TreeNode] = None): + super().__init__(index, parent) + self._value = value + + @property + def value(self) -> Any: + 
return self._value + + def serialize(self, options=None) -> Any: + return self._value + + +class InternalNode(Node): + def __init__( + self, children: List[Node], index: int = -1, parent: Optional[Node] = None + ): + super().__init__(index, parent) + self._children = children or [] + + # Set parent and index for all children + for i, child in enumerate(self._children): + if child is not None: + child.set_parent(self) + child.set_index(i) + + @property + def children(self) -> List[Node]: + return self._children + + def add_child(self, child: Node): + """Add a child to this internal node""" + child.set_parent(self) + child.set_index(len(self._children)) + self._children.append(child) + + def remove_child(self, index: int) -> Optional[Node]: + """Remove child at given index""" + if 0 <= index < len(self._children): + child = self._children.pop(index) + if child: + child.set_parent(None) + # Update indices for remaining children + for i in range(index, len(self._children)): + if self._children[i]: + self._children[i].set_index(i) + return child + return None + + @abstractmethod + def rule_name(self) -> str: + """The name of the grammar rule this represents""" + pass diff --git a/hcl2/rule_transformer/rules/whitespace.py b/hcl2/rule_transformer/rules/whitespace.py new file mode 100644 index 00000000..fa24355c --- /dev/null +++ b/hcl2/rule_transformer/rules/whitespace.py @@ -0,0 +1,68 @@ +from abc import ABC +from typing import Optional, List, Any, Tuple + +from hcl2.rule_transformer.rules.abstract import LarkToken, LarkRule +from hcl2.rule_transformer.rules.literal_rules import TokenRule +from hcl2.rule_transformer.utils import SerializationOptions + + +class NewLineOrCommentRule(TokenRule): + @staticmethod + def lark_name() -> str: + return "new_line_or_comment" + + @classmethod + def from_string(cls, string: str) -> "NewLineOrCommentRule": + return cls([LarkToken("NL_OR_COMMENT", string)]) + + def to_list( + self, options: SerializationOptions = SerializationOptions() 
+ ) -> Optional[List[str]]: + comment = self.serialize(options) + if comment == "\n": + return None + + comments = comment.split("\n") + + result = [] + for comment in comments: + comment = comment.strip() + + for delimiter in ("//", "/*", "#"): + + if comment.startswith(delimiter): + comment = comment[len(delimiter) :] + + if comment.endswith("*/"): + comment = comment[:-2] + + if comment != "": + result.append(comment.strip()) + + return result + + +class InlineCommentMixIn(LarkRule, ABC): + def _possibly_insert_null_comments(self, children: List, indexes: List[int] = None): + for index in indexes: + try: + child = children[index] + except IndexError: + children.insert(index, None) + else: + if not isinstance(child, NewLineOrCommentRule): + children.insert(index, None) + + def inline_comments(self): + result = [] + for child in self._children: + + if isinstance(child, NewLineOrCommentRule): + comments = child.to_list() + if comments is not None: + result.extend(comments) + + elif isinstance(child, InlineCommentMixIn): + result.extend(child.inline_comments()) + + return result diff --git a/hcl2/rule_transformer/transformer.py b/hcl2/rule_transformer/transformer.py new file mode 100644 index 00000000..a7d91605 --- /dev/null +++ b/hcl2/rule_transformer/transformer.py @@ -0,0 +1,228 @@ +# pylint: disable=missing-function-docstring,unused-argument +from lark import Token, Tree, v_args, Transformer, Discard +from lark.tree import Meta + +from hcl2.rule_transformer.rules.base import ( + StartRule, + BodyRule, + BlockRule, + AttributeRule, +) +from hcl2.rule_transformer.rules.containers import ( + ObjectRule, + ObjectElemRule, + ObjectElemKeyRule, + TupleRule, + ObjectElemKeyExpressionRule, + ObjectElemKeyDotAccessor, +) +from hcl2.rule_transformer.rules.expressions import ( + BinaryTermRule, + UnaryOpRule, + BinaryOpRule, + ExprTermRule, + ConditionalRule, +) +from hcl2.rule_transformer.rules.functions import ArgumentsRule, FunctionCallRule +from 
hcl2.rule_transformer.rules.indexing import ( + IndexExprTermRule, + SqbIndexRule, + ShortIndexRule, + GetAttrRule, + GetAttrExprTermRule, + AttrSplatExprTermRule, + AttrSplatRule, + FullSplatRule, + FullSplatExprTermRule, +) +from hcl2.rule_transformer.rules.literal_rules import ( + FloatLitRule, + IntLitRule, + IdentifierRule, + BinaryOperatorRule, +) +from hcl2.rule_transformer.rules.strings import ( + InterpolationRule, + StringRule, + StringPartRule, +) +from hcl2.rule_transformer.rules.tokens import ( + NAME, + IntLiteral, + FloatLiteral, + StringToken, +) +from hcl2.rule_transformer.rules.whitespace import NewLineOrCommentRule + + +class RuleTransformer(Transformer): + """Takes a syntax tree generated by the parser and + transforms it to a tree of LarkRule instances + """ + + with_meta: bool + + def transform(self, tree: Tree) -> StartRule: + return super().transform(tree) + + def __init__(self, discard_new_line_or_comments: bool = False): + super().__init__() + self.discard_new_line_or_comments = discard_new_line_or_comments + + def __default_token__(self, token: Token) -> StringToken: + return StringToken[token.type](token.value) + + def FLOAT_LITERAL(self, token: Token) -> FloatLiteral: + return FloatLiteral(token.value) + + def NAME(self, token: Token) -> NAME: + return NAME(token.value) + + def INT_LITERAL(self, token: Token) -> IntLiteral: + return IntLiteral(token.value) + + @v_args(meta=True) + def start(self, meta: Meta, args) -> StartRule: + return StartRule(args, meta) + + @v_args(meta=True) + def body(self, meta: Meta, args) -> BodyRule: + return BodyRule(args, meta) + + @v_args(meta=True) + def block(self, meta: Meta, args) -> BlockRule: + return BlockRule(args, meta) + + @v_args(meta=True) + def attribute(self, meta: Meta, args) -> AttributeRule: + return AttributeRule(args, meta) + + @v_args(meta=True) + def new_line_or_comment(self, meta: Meta, args) -> NewLineOrCommentRule: + if self.discard_new_line_or_comments: + return Discard + return 
NewLineOrCommentRule(args, meta) + + @v_args(meta=True) + def identifier(self, meta: Meta, args) -> IdentifierRule: + return IdentifierRule(args, meta) + + @v_args(meta=True) + def int_lit(self, meta: Meta, args) -> IntLitRule: + return IntLitRule(args, meta) + + @v_args(meta=True) + def float_lit(self, meta: Meta, args) -> FloatLitRule: + return FloatLitRule(args, meta) + + @v_args(meta=True) + def string(self, meta: Meta, args) -> StringRule: + return StringRule(args, meta) + + @v_args(meta=True) + def string_part(self, meta: Meta, args) -> StringPartRule: + return StringPartRule(args, meta) + + @v_args(meta=True) + def interpolation(self, meta: Meta, args) -> InterpolationRule: + return InterpolationRule(args, meta) + + @v_args(meta=True) + def expr_term(self, meta: Meta, args) -> ExprTermRule: + return ExprTermRule(args, meta) + + @v_args(meta=True) + def conditional(self, meta: Meta, args) -> ConditionalRule: + return ConditionalRule(args, meta) + + @v_args(meta=True) + def binary_operator(self, meta: Meta, args) -> BinaryOperatorRule: + return BinaryOperatorRule(args, meta) + + @v_args(meta=True) + def binary_term(self, meta: Meta, args) -> BinaryTermRule: + return BinaryTermRule(args, meta) + + @v_args(meta=True) + def unary_op(self, meta: Meta, args) -> UnaryOpRule: + return UnaryOpRule(args, meta) + + @v_args(meta=True) + def binary_op(self, meta: Meta, args) -> BinaryOpRule: + return BinaryOpRule(args, meta) + + @v_args(meta=True) + def tuple(self, meta: Meta, args) -> TupleRule: + return TupleRule(args, meta) + + @v_args(meta=True) + def object(self, meta: Meta, args) -> ObjectRule: + return ObjectRule(args, meta) + + @v_args(meta=True) + def object_elem(self, meta: Meta, args) -> ObjectElemRule: + return ObjectElemRule(args, meta) + + @v_args(meta=True) + def object_elem_key(self, meta: Meta, args) -> ObjectElemKeyRule: + return ObjectElemKeyRule(args, meta) + + @v_args(meta=True) + def object_elem_key_expression( + self, meta: Meta, args + ) -> 
ObjectElemKeyExpressionRule: + return ObjectElemKeyExpressionRule(args, meta) + + @v_args(meta=True) + def object_elem_key_dot_accessor( + self, meta: Meta, args + ) -> ObjectElemKeyDotAccessor: + return ObjectElemKeyDotAccessor(args, meta) + + @v_args(meta=True) + def arguments(self, meta: Meta, args) -> ArgumentsRule: + return ArgumentsRule(args, meta) + + @v_args(meta=True) + def function_call(self, meta: Meta, args) -> FunctionCallRule: + return FunctionCallRule(args, meta) + + # @v_args(meta=True) + # def provider_function_call(self, meta: Meta, args) -> ProviderFunctionCallRule: + # return ProviderFunctionCallRule(args, meta) + + @v_args(meta=True) + def index_expr_term(self, meta: Meta, args) -> IndexExprTermRule: + return IndexExprTermRule(args, meta) + + @v_args(meta=True) + def braces_index(self, meta: Meta, args) -> SqbIndexRule: + return SqbIndexRule(args, meta) + + @v_args(meta=True) + def short_index(self, meta: Meta, args) -> ShortIndexRule: + return ShortIndexRule(args, meta) + + @v_args(meta=True) + def get_attr(self, meta: Meta, args) -> GetAttrRule: + return GetAttrRule(args, meta) + + @v_args(meta=True) + def get_attr_expr_term(self, meta: Meta, args) -> GetAttrExprTermRule: + return GetAttrExprTermRule(args, meta) + + @v_args(meta=True) + def attr_splat(self, meta: Meta, args) -> AttrSplatRule: + return AttrSplatRule(args, meta) + + @v_args(meta=True) + def attr_splat_expr_term(self, meta: Meta, args) -> AttrSplatExprTermRule: + return AttrSplatExprTermRule(args, meta) + + @v_args(meta=True) + def full_splat(self, meta: Meta, args) -> FullSplatRule: + return FullSplatRule(args, meta) + + @v_args(meta=True) + def full_splat_expr_term(self, meta: Meta, args) -> FullSplatExprTermRule: + return FullSplatExprTermRule(args, meta) diff --git a/hcl2/rule_transformer/utils.py b/hcl2/rule_transformer/utils.py new file mode 100644 index 00000000..404bdcdd --- /dev/null +++ b/hcl2/rule_transformer/utils.py @@ -0,0 +1,70 @@ +from contextlib import 
contextmanager +from dataclasses import dataclass, replace +from typing import Generator + + +@dataclass +class SerializationOptions: + with_comments: bool = True + with_meta: bool = False + wrap_objects: bool = False + wrap_tuples: bool = False + explicit_blocks: bool = True + + +@dataclass +class DeserializationOptions: + pass + + +@dataclass +class SerializationContext: + inside_dollar_string: bool = False + + def replace(self, **kwargs) -> "SerializationContext": + return replace(self, **kwargs) + + @contextmanager + def copy(self, **kwargs) -> Generator["SerializationContext", None, None]: + """Context manager that yields a modified copy of the context""" + modified_context = self.replace(**kwargs) + yield modified_context + + @contextmanager + def modify(self, **kwargs): + original_values = {key: getattr(self, key) for key in kwargs} + + for key, value in kwargs.items(): + setattr(self, key, value) + + try: + yield + finally: + # Restore original values + for key, value in original_values.items(): + setattr(self, key, value) + + +def is_dollar_string(value: str) -> bool: + if not isinstance(value, str): + return False + return value.startswith("${") and value.endswith("}") + + +def to_dollar_string(value: str) -> str: + if not is_dollar_string(value): + return f"${{{value}}}" + return value + + +def unwrap_dollar_string(value: str) -> str: + if is_dollar_string(value): + return value[2:-1] + return value + + +def wrap_into_parentheses(value: str) -> str: + if is_dollar_string(value): + value = unwrap_dollar_string(value) + return to_dollar_string(f"({value})") + return f"({value})" diff --git a/test/helpers/hcl2_helper.py b/test/helpers/hcl2_helper.py index 5acee1e7..c39ee7fb 100644 --- a/test/helpers/hcl2_helper.py +++ b/test/helpers/hcl2_helper.py @@ -3,7 +3,7 @@ from lark import Tree from hcl2.parser import parser -from hcl2.transformer import DictTransformer +from hcl2.dict_transformer import DictTransformer class Hcl2Helper: diff --git 
a/test/unit/test_dict_transformer.py b/test/unit/test_dict_transformer.py index 122332eb..baad5ba9 100644 --- a/test/unit/test_dict_transformer.py +++ b/test/unit/test_dict_transformer.py @@ -2,7 +2,7 @@ from unittest import TestCase -from hcl2.transformer import DictTransformer +from hcl2.dict_transformer import DictTransformer class TestDictTransformer(TestCase):