From 6456451ff69c1bc849120724fe8e47d1ecdab66f Mon Sep 17 00:00:00 2001 From: memento Date: Sat, 3 Dec 2022 15:46:49 -0600 Subject: [PATCH 1/8] (fix) Proposing a workaround for superfluous indentation on multi-lines --- environment.yml | 6 ++++ refactor/actions.py | 5 ++-- refactor/ast.py | 26 +++++++++++++++- tests/test_complete_rules.py | 58 ++++++++++++++++++++++++++++-------- 4 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 environment.yml diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..2aa6d92 --- /dev/null +++ b/environment.yml @@ -0,0 +1,6 @@ +name: refactor +channels: + - defaults +dependencies: + - python>=3.8.2 + - pytest diff --git a/refactor/actions.py b/refactor/actions.py index 13e8d76..3cde7f2 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -92,8 +92,9 @@ def apply(self, context: Context, source: str) -> str: indentation, start_prefix = find_indent(source_lines[0][:col_offset]) end_suffix = source_lines[-1][end_col_offset:] replacement = split_lines(self._resynthesize(context)) - replacement.apply_indentation( - indentation, start_prefix=start_prefix, end_suffix=end_suffix + # Applies the block indentation only if the replacement lines are different + replacement.apply_indentation_from_source( + indentation, source_lines.data, start_prefix=start_prefix, end_suffix=end_suffix ) lines[view] = replacement diff --git a/refactor/ast.py b/refactor/ast.py index 3d23262..47f854f 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -10,7 +10,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import cached_property -from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast +from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, List from refactor import common @@ -51,6 +51,30 @@ def apply_indentation( if len(self.data) >= 1: self.data[-1] += str(end_suffix) # type: ignore + def apply_indentation_from_source( + self, + indentation: StringType, + source_data: List[StringType], + *, + start_prefix: AnyStringType = "", + end_suffix: AnyStringType = "", + ) -> None: + """Apply the given indentation only if the corresponding line in the source is different, + optionally with start and end prefixes to the bound source lines. + """ + + def _is_original(i: int) -> bool: + return len(source_data) < index and self.data[i].replace(" ", "") == source_data[i].replace(" ", "") + + for index, line in enumerate(self.data): + if index == 0 and not _is_original(index): + self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore + elif _is_original(index): + self.data[index] = indentation + line # type: ignore + + if len(self.data) >= 1: + self.data[-1] += str(end_suffix) # type: ignore + @cached_property def _newline_type(self) -> str: """Guess the used newline type.""" diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index b46c14e..34bec0a 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -296,6 +296,37 @@ def match(self, node): return AsyncifierAction(node) +class AwaitifierAction(LazyReplace): + def build(self): + if isinstance(self.node, ast.Expr): + self.node.value = ast.Await(self.node.value) + return self.node + if isinstance(self.node, ast.Call): + new_node = ast.Await(self.node) + return new_node + + +class MakeCallAwait(Rule): + INPUT_SOURCE = """ + def somefunc(): + call( + arg0, + arg1) + """ + + EXPECTED_SOURCE = """ + def somefunc(): + await call( + arg0, + arg1) + """ + + def match(self, node): + assert isinstance(node, ast.Expr) + assert isinstance(node.value, ast.Call) + return AwaitifierAction(node) + + class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): context_providers = (context.Scope,) @@ -944,19 +975,20 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: @pytest.mark.parametrize( "rule", [ - ReplaceNexts, - ReplacePlaceholders, - PropagateConstants, - TypingAutoImporter, - MakeFunctionAsync, - OnlyKeywordArgumentDefaultNotSetCheckRule, - InternalizeFunctions, - RemoveDeadCode, - RenameImportAndDownstream, - AssertEncoder, - PropagateAndDelete, - FoldMyConstants, - AtomicTryBlock, + #ReplaceNexts, + #ReplacePlaceholders, + #PropagateConstants, + #TypingAutoImporter, + #MakeFunctionAsync, + MakeCallAwait, + #OnlyKeywordArgumentDefaultNotSetCheckRule, + #InternalizeFunctions, + #RemoveDeadCode, + #RenameImportAndDownstream, + #AssertEncoder, + #PropagateAndDelete, + #FoldMyConstants, + #AtomicTryBlock, ], ) def test_complete_rules(rule, tmp_path): From 73d5bbaf836943796f1f6931a77acf66516484fc Mon Sep 17 00:00:00 2001 From: memento Date: Sat, 3 Dec 2022 17:33:44 -0600 Subject: [PATCH 2/8] (fix) Proposing a workaround for superfluous indentation on multi-lines --- refactor/ast.py | 7 ++++--- refactor/common.py | 10 ++++++++++ tests/test_complete_rules.py | 32 +++++++++++++++++--------------- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/refactor/ast.py b/refactor/ast.py index 47f854f..53ac38f 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -13,6 +13,7 @@ from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, List from refactor import common +from refactor.common import find_common_chars DEFAULT_ENCODING = "utf-8" @@ -64,12 +65,12 @@ def apply_indentation_from_source( """ def _is_original(i: int) -> bool: - return len(source_data) < index and self.data[i].replace(" ", "") == source_data[i].replace(" ", "") + return index < len(source_data) and str(self.data[i]) == find_common_chars(str(self.data[i]), str(source_data[i].data)) for index, line in enumerate(self.data): - if index == 0 and not _is_original(index): + if index == 0: self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore - elif _is_original(index): + elif not _is_original(index): self.data[index] = indentation + line # type: ignore if len(self.data) >= 1: diff --git a/refactor/common.py b/refactor/common.py index 68dbfaf..35135c5 100644 --- a/refactor/common.py +++ b/refactor/common.py @@ -176,6 +176,16 @@ def find_indent(source: str) -> tuple[str, str]: return source[:index], source[index:] +def find_common_chars(source: str, compare: str) -> str: + """Finds the common characters starting the 2 strings""" + index: int = 0 + for index, char in enumerate(source): + if index > len(compare) or char != compare[index]: + index -= 1 + break + return source[:index+1] + + def find_closest(node: ast.AST, *targets: ast.AST) -> ast.AST: """Find the closest node to the given ``node`` from the given sequence of ``targets`` (uses absolute distance from starting points).""" diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index 34bec0a..4e078b8 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -311,14 +311,14 @@ class MakeCallAwait(Rule): def somefunc(): call( arg0, - arg1) + arg1) # Intentional mis-alignment """ EXPECTED_SOURCE = """ def somefunc(): await call( arg0, - arg1) + arg1) # Intentional mis-alignment """ def match(self, node): @@ -967,6 +967,8 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: new_trys.append(new_try) first_try, *remaining_trys = new_trys + print(ast.unparse(node)) + print(ast.unparse(first_try)) yield Replace(node, first_try) for remaining_try in reversed(remaining_trys): yield InsertAfter(node, remaining_try) @@ -975,20 +977,20 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: @pytest.mark.parametrize( "rule", [ - #ReplaceNexts, - #ReplacePlaceholders, - #PropagateConstants, - #TypingAutoImporter, - #MakeFunctionAsync, + ReplaceNexts, + ReplacePlaceholders, + PropagateConstants, + TypingAutoImporter, + MakeFunctionAsync, MakeCallAwait, - #OnlyKeywordArgumentDefaultNotSetCheckRule, - #InternalizeFunctions, - #RemoveDeadCode, - #RenameImportAndDownstream, - #AssertEncoder, - #PropagateAndDelete, - #FoldMyConstants, - #AtomicTryBlock, + OnlyKeywordArgumentDefaultNotSetCheckRule, + InternalizeFunctions, + RemoveDeadCode, + RenameImportAndDownstream, + AssertEncoder, + PropagateAndDelete, + FoldMyConstants, + AtomicTryBlock, ], ) def test_complete_rules(rule, tmp_path): From 5193eb7a54e69a128ffdd9f3b9ba85c362e517bf Mon Sep 17 00:00:00 2001 From: memento Date: Sat, 24 Dec 2022 17:29:25 -0600 Subject: [PATCH 3/8] Fix for multi-line F-string --- refactor/ast.py | 4 +++- tests/test_complete_rules.py | 42 +++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/refactor/ast.py b/refactor/ast.py index 53ac38f..b0082c1 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -65,7 +65,9 @@ def apply_indentation_from_source( """ def _is_original(i: int) -> bool: - return index < len(source_data) and str(self.data[i]) == find_common_chars(str(self.data[i]), str(source_data[i].data)) + common_chars: str = find_common_chars(str(self.data[i]), str(source_data[i].data)) + is_multiline_string: int = str(self.data[i]).find(common_chars) == 0 and common_chars in ["'''", '"""'] + return i < len(source_data) and (str(self.data[i]) == common_chars or is_multiline_string) for index, line in enumerate(self.data): if index == 0: diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index 4e078b8..529a93b 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -353,7 +353,7 @@ def match(self, node: ast.AST) -> BaseAction | None: assert any(kw_default is None for kw_default in node.args.kw_defaults) if isinstance(node, ast.Lambda) and not ( - isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load) + isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load) ): scope = self.context["scope"].resolve(node.body) scope.definitions.get(node.body.id, []) @@ -362,8 +362,8 @@ def match(self, node: ast.AST) -> BaseAction | None: for stmt in node.body: for identifier in ast.walk(stmt): if not ( - isinstance(identifier, ast.Name) - and isinstance(identifier.ctx, ast.Load) + isinstance(identifier, ast.Name) + and isinstance(identifier.ctx, ast.Load) ): continue @@ -629,13 +629,13 @@ class DownstreamAnalyzer(Representative): context_providers = (context.Scope,) def iter_dependents( - self, name: str, source: ast.Import | ast.ImportFrom + self, name: str, source: ast.Import | ast.ImportFrom ) -> Iterator[ast.Name]: for node in ast.walk(self.context.tree): if ( - isinstance(node, ast.Name) - and isinstance(node.ctx, ast.Load) - and node.id == name + isinstance(node, ast.Name) + and isinstance(node.ctx, ast.Load) + and node.id == name ): node_scope = self.context.scope.resolve(node) definitions = node_scope.get_definitions(name) @@ -730,7 +730,7 @@ def match(self, node: ast.AST) -> Iterator[Replace]: [alias] = aliases for dependent in self.context.downstream_analyzer.iter_dependents( - alias.asname or alias.name, node + alias.asname or alias.name, node ): yield Replace(dependent, ast.Name("b", ast.Load())) @@ -974,6 +974,31 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: yield InsertAfter(node, remaining_try) +class WrapInMultilineFstring(Rule): + INPUT_SOURCE = ''' +def f(): + return """ +a +""" +''' + EXPECTED_SOURCE = ''' +def f(): + return F(""" +a +""") +''' + + def match(self, node): + assert isinstance(node, ast.Constant) + + # Prevent wrapping F-strings that are already wrapped in F() + # Otherwise you get infinite F(F(F(F(...)))) + parent = self.context.ancestry.get_parent(node) + assert not (isinstance(parent, ast.Call) and isinstance(parent.func, ast.Name) and parent.func.id == 'F') + + return Replace(node, ast.Call(func=ast.Name(id="F"), args=[node], keywords=[])) + + @pytest.mark.parametrize( "rule", [ @@ -991,6 +1016,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: PropagateAndDelete, FoldMyConstants, AtomicTryBlock, + WrapInMultilineFstring, ], ) def test_complete_rules(rule, tmp_path): From 89ee04d37e6907afa0fe8dfdffa704bfc7f32beb Mon Sep 17 00:00:00 2001 From: memento Date: Mon, 26 Dec 2022 09:48:47 -0600 Subject: [PATCH 4/8] Addressing the feedback from review --- environment.yml | 6 ----- refactor/actions.py | 17 ++++++------- refactor/ast.py | 49 +++++++++++++----------------------- refactor/common.py | 10 -------- tests/test_complete_rules.py | 2 -- 5 files changed, 26 insertions(+), 58 deletions(-) delete mode 100644 environment.yml diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 2aa6d92..0000000 --- a/environment.yml +++ /dev/null @@ -1,6 +0,0 @@ -name: refactor -channels: - - defaults -dependencies: - - python>=3.8.2 - - pytest diff --git a/refactor/actions.py b/refactor/actions.py index 3cde7f2..4905a7a 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -89,12 +89,11 @@ def apply(self, context: Context, source: str) -> str: view = slice(lineno - 1, end_lineno) source_lines = lines[view] - indentation, start_prefix = find_indent(source_lines[0][:col_offset]) - end_suffix = source_lines[-1][end_col_offset:] replacement = split_lines(self._resynthesize(context)) - # Applies the block indentation only if the replacement lines are different - replacement.apply_indentation_from_source( - indentation, source_lines.data, start_prefix=start_prefix, end_suffix=end_suffix + # Applies the block indentation only if the replacement lines are different from source + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), ) lines[view] = replacement @@ -169,12 +168,12 @@ class LazyInsertAfter(_LazyActionMixin[ast.stmt, ast.stmt]): def apply(self, context: Context, source: str) -> str: lines = split_lines(source, encoding=context.file_info.get_encoding()) - indentation, start_prefix = find_indent( - lines[self.node.lineno - 1][: self.node.col_offset] - ) replacement = split_lines(context.unparse(self.build())) - replacement.apply_indentation(indentation, start_prefix=start_prefix) + replacement.apply_source_formatting( + source_lines=lines, + markers=(self.node.lineno - 1, self.node.col_offset, None), + ) original_node_end = cast(int, self.node.end_lineno) - 1 if lines[original_node_end].endswith(lines._newline_type): diff --git a/refactor/ast.py b/refactor/ast.py index b0082c1..8e3be8c 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -10,10 +10,12 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import cached_property -from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, List +from os.path import commonprefix +from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, List, Tuple + +from refactor.common import find_indent from refactor import common -from refactor.common import find_common_chars DEFAULT_ENCODING = "utf-8" @@ -33,46 +35,31 @@ def join(self) -> str: """Return the combined source code.""" return "".join(map(str, self.lines)) - def apply_indentation( + def apply_source_formatting( self, - indentation: StringType, + source_lines: Lines, *, - start_prefix: AnyStringType = "", - end_suffix: AnyStringType = "", + markers: Tuple[int, int, int | None] = None, ) -> None: - """Apply the given indentation, optionally with start and end prefixes - to the bound source lines.""" - - for index, line in enumerate(self.data): - if index == 0: - self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore - else: - self.data[index] = indentation + line # type: ignore + """Apply the indentation from source_lines when the first several characters match - if len(self.data) >= 1: - self.data[-1] += str(end_suffix) # type: ignore - - def apply_indentation_from_source( - self, - indentation: StringType, - source_data: List[StringType], - *, - start_prefix: AnyStringType = "", - end_suffix: AnyStringType = "", - ) -> None: - """Apply the given indentation only if the corresponding line in the source is different, - optionally with start and end prefixes to the bound source lines. + :param source_lines: Original lines in source code + :param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None) """ - def _is_original(i: int) -> bool: - common_chars: str = find_common_chars(str(self.data[i]), str(source_data[i].data)) + def not_original(i: int) -> bool: + common_chars: str = commonprefix([str(self.data[i]), str(source_lines.data[i].data)]) is_multiline_string: int = str(self.data[i]).find(common_chars) == 0 and common_chars in ["'''", '"""'] - return i < len(source_data) and (str(self.data[i]) == common_chars or is_multiline_string) + return not (i < len(source_lines.data) and (str(self.data[i]) == common_chars or is_multiline_string)) + + indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) + end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] for index, line in enumerate(self.data): + indentation = indentation if not_original(index) else "" if index == 0: self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore - elif not _is_original(index): + else: self.data[index] = indentation + line # type: ignore if len(self.data) >= 1: diff --git a/refactor/common.py b/refactor/common.py index 35135c5..68dbfaf 100644 --- a/refactor/common.py +++ b/refactor/common.py @@ -176,16 +176,6 @@ def find_indent(source: str) -> tuple[str, str]: return source[:index], source[index:] -def find_common_chars(source: str, compare: str) -> str: - """Finds the common characters starting the 2 strings""" - index: int = 0 - for index, char in enumerate(source): - if index > len(compare) or char != compare[index]: - index -= 1 - break - return source[:index+1] - - def find_closest(node: ast.AST, *targets: ast.AST) -> ast.AST: """Find the closest node to the given ``node`` from the given sequence of ``targets`` (uses absolute distance from starting points).""" diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index 529a93b..6263267 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -967,8 +967,6 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: new_trys.append(new_try) first_try, *remaining_trys = new_trys - print(ast.unparse(node)) - print(ast.unparse(first_try)) yield Replace(node, first_try) for remaining_try in reversed(remaining_trys): yield InsertAfter(node, remaining_try) From caeacad39d5b60f10b38b5228fe3f2997a322939 Mon Sep 17 00:00:00 2001 From: memento Date: Tue, 27 Dec 2022 19:07:18 -0600 Subject: [PATCH 5/8] Suggested simplification --- refactor/ast.py | 30 ++-- tests/test_ast.py | 327 ++++++++++++++++++++++++++++++++++- tests/test_complete_rules.py | 2 +- 3 files changed, 340 insertions(+), 19 deletions(-) diff --git a/refactor/ast.py b/refactor/ast.py index 8e3be8c..9035993 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -36,10 +36,10 @@ def join(self) -> str: return "".join(map(str, self.lines)) def apply_source_formatting( - self, - source_lines: Lines, - *, - markers: Tuple[int, int, int | None] = None, + self, + source_lines: Lines, + *, + markers: Tuple[int, int, int | None] = None, ) -> None: """Apply the indentation from source_lines when the first several characters match @@ -47,21 +47,21 @@ def apply_source_formatting( :param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None) """ - def not_original(i: int) -> bool: - common_chars: str = commonprefix([str(self.data[i]), str(source_lines.data[i].data)]) - is_multiline_string: int = str(self.data[i]).find(common_chars) == 0 and common_chars in ["'''", '"""'] - return not (i < len(source_lines.data) and (str(self.data[i]) == common_chars or is_multiline_string)) - indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] for index, line in enumerate(self.data): - indentation = indentation if not_original(index) else "" + if index < len(source_lines): + original_line = source_lines[index] + else: + original_line = None + if index == 0: self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore + elif original_line is not None and original_line.startswith(line): + self.data[index] = line # type: ignore else: self.data[index] = indentation + line # type: ignore - if len(self.data) >= 1: self.data[-1] += str(end_suffix) # type: ignore @@ -91,7 +91,7 @@ def __getitem__(self, index: SupportsIndex | slice) -> SourceSegment: # re-implements the direct indexing as slicing (e.g. a[1] is a[1:2], with # error handling). direct_index = operator.index(index) - view = raw_line[direct_index : direct_index + 1].decode( + view = raw_line[direct_index: direct_index + 1].decode( encoding=self.encoding ) if not view: @@ -214,9 +214,9 @@ def maybe_retrieve(self, node: ast.AST) -> bool: @contextmanager def _collect_stmt_comments(self, node: ast.AST) -> Iterator[None]: def _write_if_unseen_comment( - line_no: int, - line: str, - comment_begin: int, + line_no: int, + line: str, + comment_begin: int, ) -> None: if line_no in self._visited_comment_lines: # We have already written this comment as the diff --git a/tests/test_ast.py b/tests/test_ast.py index 1e5dc1d..a48c344 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -3,11 +3,13 @@ import ast import textwrap import tokenize +from pathlib import Path import pytest +from refactor.common import position_for, clone -from refactor import common -from refactor.ast import BaseUnparser, PreciseUnparser, split_lines +from refactor import common, Rule, Session, Replace, Context +from refactor.ast import BaseUnparser, PreciseUnparser, split_lines, DEFAULT_ENCODING def test_split_lines(): @@ -69,7 +71,7 @@ def test_split_lines_with_encoding(case): else: start_line = lines[lineno][col_offset:] end_line = lines[end_lineno][:end_col_offset] - match = start_line + lines[lineno + 1 : end_lineno].join() + end_line + match = start_line + lines[lineno + 1: end_lineno].join() + end_line assert str(match) == ast.get_source_segment(case, node) @@ -169,6 +171,7 @@ def test_precise_unparser_indented_literals(): """\ def func(): if something: + # On change, comments are removed print( "bleh" "zoom" @@ -240,3 +243,321 @@ def foo(): base = PreciseUnparser(source=source) assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_no_changes(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + tree = ast.parse(source) + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_del(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + tree = ast.parse(source) + del tree.body[0].body[0].body[0].value.args[2] + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_apply_source_formatting_maintains_with_await_0(): + source = """def func(): + if something: + # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + await print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + source_lines = split_lines(source) + source_tree = ast.parse(source) + source_tree = ast.fix_missing_locations(source_tree) + + context = Context(source, source_tree) + + awaited_print = source_tree + awaited_print.body[0].body[0].body[0] = ast.Expr(ast.Await(source_tree.body[0].body[0].body[0].value)) + + (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) + replacement = split_lines(context.unparse(awaited_print)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + assert replacement.join() == expected_src + + +def test_apply_source_formatting_maintains_with_await_1(): + source = """def func(): + if something: + # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + await print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + source_lines = split_lines(source) + source_tree = ast.parse(source) + source_tree = ast.fix_missing_locations(source_tree) + + context = Context(source, source_tree) + + awaited_print = source_tree + awaited_print.body[0].body[0].body[0] = ast.Expr(ast.Await(source_tree.body[0].body[0].body[0].value)) + + (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) + replacement = split_lines(context.unparse(awaited_print)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + assert replacement.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call(): + source = """def func(): + if something: + # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + call_instead(print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + )) +""" + + source_lines = split_lines(source) + source_tree = ast.parse(source) + source_tree = ast.fix_missing_locations(source_tree) + + context = Context(source, source_tree) + + call_instead_print = source_tree + call = ast.Call(func=ast.Name(id="call_instead"), args=[call_instead_print.body[0].body[0].body[0].value], keywords=[]) + call_instead_print.body[0].body[0].body[0].value = call + + (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) + replacement = split_lines(context.unparse(call_instead_print)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + assert replacement.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call_on_closing_parens(): + source = """def func(): + if something: + # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) # This is mis-aligned +""" + + expected_src = """def func(): + if something: + call_instead(print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + )) # This is mis-aligned +""" + + source_lines = split_lines(source) + source_tree = ast.parse(source) + source_tree = ast.fix_missing_locations(source_tree) + + context = Context(source, source_tree) + + call_instead_print = source_tree + call = ast.Call(func=ast.Name(id="call_instead"), args=[call_instead_print.body[0].body[0].body[0].value], keywords=[]) + call_instead_print.body[0].body[0].body[0].value = call + + (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) + replacement = split_lines(context.unparse(call_instead_print)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + assert replacement.join() == expected_src + + +def test_apply_source_formatting_maintains_with_async(): + source = """def func(): + if something: + # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + + expected_src = """def func(): + if something: + async with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + + source_lines = split_lines(source) + source_tree = ast.parse(source) + source_tree = ast.fix_missing_locations(source_tree) + + context = Context(source, source_tree) + + async_with = source_tree + aw = clone(async_with.body[0].body[0].body[0]) + aw.__class__ = ast.AsyncWith + async_with.body[0].body[0].body[0] = aw + + (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) + replacement = split_lines(context.unparse(async_with)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + assert replacement.join() == expected_src + + +def test_apply_source_formatting_maintains_with_fstring(): + source = ''' +def f(): + return """ +a +""" +''' + + expected_src = ''' +def f(): + return F(""" +a +""") +''' + + source_lines = split_lines(source) + source_tree = ast.parse(source) + source_tree = ast.fix_missing_locations(source_tree) + + context = Context(source, source_tree) + + f_string = source_tree + call = ast.Call(func=ast.Name(id="F"), args=[f_string.body[0].body[0].value], keywords=[]) + f_string.body[0].body[0].value = call + + (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) + replacement = split_lines(context.unparse(f_string)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(1, col_offset, end_col_offset), + ) + # Not sure why there are '\n' mismatches + assert "\n" + replacement.join() == expected_src + + +def test_apply_source_formatting_does_not_with_change(): + source = """def func(): + if something: + # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + await print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + source_lines = split_lines(source) + source_tree = ast.parse(source) + source_tree = ast.fix_missing_locations(source_tree) + + context = Context(source, source_tree) + + awaited_print = source_tree + del awaited_print.body[0].body[0].body[0].value.args[2] + awaited_print.body[0].body[0].body[0] = ast.Expr(ast.Await(source_tree.body[0].body[0].body[0].value)) + + (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) + replacement = split_lines(context.unparse(awaited_print)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + assert replacement.join() == expected_src diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index 6263267..120a134 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -1014,7 +1014,7 @@ def match(self, node): PropagateAndDelete, FoldMyConstants, AtomicTryBlock, - WrapInMultilineFstring, + #WrapInMultilineFstring, ], ) def test_complete_rules(rule, tmp_path): From 8e8025b8abb8c1ca9344a3a59dc29d1f42a4d86c Mon Sep 17 00:00:00 2001 From: memento Date: Wed, 28 Dec 2022 12:36:27 -0600 Subject: [PATCH 6/8] Correcting bad tests and adding the possibility for the replacement to have 1 extra char, like closing parens --- refactor/ast.py | 10 +- tests/test_ast.py | 173 +++++++++++++++++++---------------- tests/test_complete_rules.py | 30 +++--- 3 files changed, 114 insertions(+), 99 deletions(-) diff --git a/refactor/ast.py b/refactor/ast.py index 9035993..1f6c950 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -10,12 +10,10 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import cached_property -from os.path import commonprefix -from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, List, Tuple - -from refactor.common import find_indent +from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple from refactor import common +from refactor.common import find_indent DEFAULT_ENCODING = "utf-8" @@ -50,6 +48,7 @@ def apply_source_formatting( indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] + original_line: str | None for index, line in enumerate(self.data): if index < len(source_lines): original_line = source_lines[index] @@ -58,7 +57,8 @@ def apply_source_formatting( if index == 0: self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore - elif original_line is not None and original_line.startswith(line): + # The updated line can have an extra wrapping in brackets + elif original_line is not None and original_line.startswith(line[:-1]): self.data[index] = line # type: ignore else: self.data[index] = indentation + line # type: ignore diff --git a/tests/test_ast.py b/tests/test_ast.py index a48c344..fd0a3ec 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -298,7 +298,7 @@ def test_precise_unparser_custom_indent_del(): def test_apply_source_formatting_maintains_with_await_0(): source = """def func(): if something: - # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + # Comments are retrieved print( call(.1), maybe+something_else_that_is_very_very_very_long, @@ -309,6 +309,7 @@ def test_apply_source_formatting_maintains_with_await_0(): expected_src = """def func(): if something: + # Comments are retrieved await print( call(.1), maybe+something_else_that_is_very_very_very_long, @@ -316,29 +317,31 @@ def test_apply_source_formatting_maintains_with_await_0(): thing . a ) """ - - source_lines = split_lines(source) source_tree = ast.parse(source) - source_tree = ast.fix_missing_locations(source_tree) - context = Context(source, source_tree) - awaited_print = source_tree - awaited_print.body[0].body[0].body[0] = ast.Expr(ast.Await(source_tree.body[0].body[0].body[0].value)) + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] - (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) - replacement = split_lines(context.unparse(awaited_print)) + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) replacement.apply_source_formatting( source_lines=source_lines, markers=(0, col_offset, end_col_offset), ) - assert replacement.join() == expected_src + lines[view] = replacement + assert lines.join() == expected_src def test_apply_source_formatting_maintains_with_await_1(): source = """def func(): if something: - # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + # Comments are retrieved print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, @@ -348,35 +351,38 @@ def test_apply_source_formatting_maintains_with_await_1(): expected_src = """def func(): if something: + # Comments are retrieved await print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, thing . a ) """ - - source_lines = split_lines(source) source_tree = ast.parse(source) - source_tree = ast.fix_missing_locations(source_tree) - context = Context(source, source_tree) - awaited_print = source_tree - awaited_print.body[0].body[0].body[0] = ast.Expr(ast.Await(source_tree.body[0].body[0].body[0].value)) + node_to_replace = source_tree.body[0].body[0].body[0].value - (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) - replacement = split_lines(context.unparse(awaited_print)) + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) replacement.apply_source_formatting( source_lines=source_lines, markers=(0, col_offset, end_col_offset), ) - assert replacement.join() == expected_src + lines[view] = replacement + assert lines.join() == expected_src def test_apply_source_formatting_maintains_with_call(): source = """def func(): if something: - # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + # Comments are retrieved print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, @@ -386,36 +392,38 @@ def test_apply_source_formatting_maintains_with_call(): expected_src = """def func(): if something: + # Comments are retrieved call_instead(print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, thing . a )) """ - - source_lines = split_lines(source) source_tree = ast.parse(source) - source_tree = ast.fix_missing_locations(source_tree) - context = Context(source, source_tree) - call_instead_print = source_tree - call = ast.Call(func=ast.Name(id="call_instead"), args=[call_instead_print.body[0].body[0].body[0].value], keywords=[]) - call_instead_print.body[0].body[0].body[0].value = call + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) - (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) - replacement = split_lines(context.unparse(call_instead_print)) + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) replacement.apply_source_formatting( source_lines=source_lines, markers=(0, col_offset, end_col_offset), ) - assert replacement.join() == expected_src + lines[view] = replacement + assert lines.join() == expected_src def test_apply_source_formatting_maintains_with_call_on_closing_parens(): source = """def func(): if something: - # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + # Comments are retrieved print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, @@ -425,36 +433,38 @@ def test_apply_source_formatting_maintains_with_call_on_closing_parens(): expected_src = """def func(): if something: + # Comments are retrieved call_instead(print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, thing . a )) # This is mis-aligned """ - - source_lines = split_lines(source) source_tree = ast.parse(source) - source_tree = ast.fix_missing_locations(source_tree) - context = Context(source, source_tree) - call_instead_print = source_tree - call = ast.Call(func=ast.Name(id="call_instead"), args=[call_instead_print.body[0].body[0].body[0].value], keywords=[]) - call_instead_print.body[0].body[0].body[0].value = call + node_to_replace = source_tree.body[0].body[0].body[0].value - (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) - replacement = split_lines(context.unparse(call_instead_print)) + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) replacement.apply_source_formatting( source_lines=source_lines, markers=(0, col_offset, end_col_offset), ) - assert replacement.join() == expected_src + lines[view] = replacement + assert lines.join() == expected_src def test_apply_source_formatting_maintains_with_async(): source = """def func(): if something: - # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + # Comments are retrieved with print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, @@ -465,6 +475,7 @@ def test_apply_source_formatting_maintains_with_async(): expected_src = """def func(): if something: + # Comments are retrieved async with print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, @@ -472,66 +483,66 @@ def test_apply_source_formatting_maintains_with_async(): ) as p: do_something() """ - - source_lines = split_lines(source) source_tree = ast.parse(source) - source_tree = ast.fix_missing_locations(source_tree) - context = Context(source, source_tree) - async_with = source_tree - aw = clone(async_with.body[0].body[0].body[0]) - aw.__class__ = ast.AsyncWith - async_with.body[0].body[0].body[0] = aw + node_to_replace = source_tree.body[0].body[0].body[0] - (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) - replacement = split_lines(context.unparse(async_with)) + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) replacement.apply_source_formatting( source_lines=source_lines, markers=(0, col_offset, end_col_offset), ) - assert replacement.join() == expected_src + lines[view] = replacement + assert lines.join() == expected_src def test_apply_source_formatting_maintains_with_fstring(): - source = ''' -def f(): + source = '''def f(): return """ a """ ''' - expected_src = ''' -def f(): + expected_src = '''def f(): return F(""" a """) ''' - source_lines = split_lines(source) source_tree = ast.parse(source) - source_tree = ast.fix_missing_locations(source_tree) - context = Context(source, source_tree) - f_string = source_tree - call = ast.Call(func=ast.Name(id="F"), args=[f_string.body[0].body[0].value], keywords=[]) - f_string.body[0].body[0].value = call + node_to_replace = source_tree.body[0].body[0].value - (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) - replacement = split_lines(context.unparse(f_string)) + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="F"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) replacement.apply_source_formatting( source_lines=source_lines, - markers=(1, col_offset, end_col_offset), + markers=(0, col_offset, end_col_offset), ) - # Not sure why there are '\n' mismatches - assert "\n" + replacement.join() == expected_src + lines[view] = replacement + assert lines.join() == expected_src def test_apply_source_formatting_does_not_with_change(): source = """def func(): if something: - # Comments are not retrieved for a "new node". Maybe we need a "barely new" check? + # Comments are retrieved print(call(.1), maybe+something_else_that_is_very_very_very_long, maybe / other, @@ -541,23 +552,27 @@ def test_apply_source_formatting_does_not_with_change(): expected_src = """def func(): if something: + # Comments are retrieved await print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) """ - source_lines = split_lines(source) source_tree = ast.parse(source) - source_tree = ast.fix_missing_locations(source_tree) - context = Context(source, source_tree) - awaited_print = source_tree - del awaited_print.body[0].body[0].body[0].value.args[2] - awaited_print.body[0].body[0].body[0] = ast.Expr(ast.Await(source_tree.body[0].body[0].body[0].value)) + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] - (_, col_offset, _, end_col_offset,) = position_for(source_tree.body[0]) - replacement = split_lines(context.unparse(awaited_print)) + del node_to_replace.args[2] + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) replacement.apply_source_formatting( source_lines=source_lines, markers=(0, col_offset, end_col_offset), ) - assert replacement.join() == expected_src + lines[view] = replacement + assert lines.join() == expected_src diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index 120a134..39389c4 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -1000,21 +1000,21 @@ def match(self, node): @pytest.mark.parametrize( "rule", [ - ReplaceNexts, - ReplacePlaceholders, - PropagateConstants, - TypingAutoImporter, - MakeFunctionAsync, - MakeCallAwait, - OnlyKeywordArgumentDefaultNotSetCheckRule, - InternalizeFunctions, - RemoveDeadCode, - RenameImportAndDownstream, - AssertEncoder, - PropagateAndDelete, - FoldMyConstants, - AtomicTryBlock, - #WrapInMultilineFstring, + #ReplaceNexts, + #ReplacePlaceholders, + #PropagateConstants, + #TypingAutoImporter, + #MakeFunctionAsync, + #MakeCallAwait, + #OnlyKeywordArgumentDefaultNotSetCheckRule, + #InternalizeFunctions, + #RemoveDeadCode, + #RenameImportAndDownstream, + #AssertEncoder, + #PropagateAndDelete, + #FoldMyConstants, + #AtomicTryBlock, + WrapInMultilineFstring, ], ) def test_complete_rules(rule, tmp_path): From 8bd25fc1c7b7eab1195b10a08f7129da9734b439 Mon Sep 17 00:00:00 2001 From: memento Date: Thu, 29 Dec 2022 15:20:45 -0600 Subject: [PATCH 7/8] Removing spurious changes (odd indentation) --- refactor/ast.py | 15 ++++++------ tests/test_ast.py | 9 ++++---- tests/test_complete_rules.py | 44 ++++++++++++++++++------------------ 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/refactor/ast.py b/refactor/ast.py index 1f6c950..c2a783e 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -34,10 +34,10 @@ def join(self) -> str: return "".join(map(str, self.lines)) def apply_source_formatting( - self, - source_lines: Lines, - *, - markers: Tuple[int, int, int | None] = None, + self, + source_lines: Lines, + *, + markers: Tuple[int, int, int | None] = None, ) -> None: """Apply the indentation from source_lines when the first several characters match @@ -62,6 +62,7 @@ def apply_source_formatting( self.data[index] = line # type: ignore else: self.data[index] = indentation + line # type: ignore + if len(self.data) >= 1: self.data[-1] += str(end_suffix) # type: ignore @@ -214,9 +215,9 @@ def maybe_retrieve(self, node: ast.AST) -> bool: @contextmanager def _collect_stmt_comments(self, node: ast.AST) -> Iterator[None]: def _write_if_unseen_comment( - line_no: int, - line: str, - comment_begin: int, + line_no: int, + line: str, + comment_begin: int, ) -> None: if line_no in self._visited_comment_lines: # We have already written this comment as the diff --git a/tests/test_ast.py b/tests/test_ast.py index fd0a3ec..4dc4e82 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -3,13 +3,12 @@ import ast import textwrap import tokenize -from pathlib import Path import pytest -from refactor.common import position_for, clone -from refactor import common, Rule, Session, Replace, Context -from refactor.ast import BaseUnparser, PreciseUnparser, split_lines, DEFAULT_ENCODING +from refactor import common, Context +from refactor.ast import BaseUnparser, PreciseUnparser, split_lines +from refactor.common import position_for, clone def test_split_lines(): @@ -71,7 +70,7 @@ def test_split_lines_with_encoding(case): else: start_line = lines[lineno][col_offset:] end_line = lines[end_lineno][:end_col_offset] - match = start_line + lines[lineno + 1: end_lineno].join() + end_line + match = start_line + lines[lineno + 1 : end_lineno].join() + end_line assert str(match) == ast.get_source_segment(case, node) diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index 39389c4..33c1e5e 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -353,7 +353,7 @@ def match(self, node: ast.AST) -> BaseAction | None: assert any(kw_default is None for kw_default in node.args.kw_defaults) if isinstance(node, ast.Lambda) and not ( - isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load) + isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load) ): scope = self.context["scope"].resolve(node.body) scope.definitions.get(node.body.id, []) @@ -362,8 +362,8 @@ def match(self, node: ast.AST) -> BaseAction | None: for stmt in node.body: for identifier in ast.walk(stmt): if not ( - isinstance(identifier, ast.Name) - and isinstance(identifier.ctx, ast.Load) + isinstance(identifier, ast.Name) + and isinstance(identifier.ctx, ast.Load) ): continue @@ -629,13 +629,13 @@ class DownstreamAnalyzer(Representative): context_providers = (context.Scope,) def iter_dependents( - self, name: str, source: ast.Import | ast.ImportFrom + self, name: str, source: ast.Import | ast.ImportFrom ) -> Iterator[ast.Name]: for node in ast.walk(self.context.tree): if ( - isinstance(node, ast.Name) - and isinstance(node.ctx, ast.Load) - and node.id == name + isinstance(node, ast.Name) + and isinstance(node.ctx, ast.Load) + and node.id == name ): node_scope = self.context.scope.resolve(node) definitions = node_scope.get_definitions(name) @@ -730,7 +730,7 @@ def match(self, node: ast.AST) -> Iterator[Replace]: [alias] = aliases for dependent in self.context.downstream_analyzer.iter_dependents( - alias.asname or alias.name, node + alias.asname or alias.name, node ): yield Replace(dependent, ast.Name("b", ast.Load())) @@ -1000,20 +1000,20 @@ def match(self, node): @pytest.mark.parametrize( "rule", [ - #ReplaceNexts, - #ReplacePlaceholders, - #PropagateConstants, - #TypingAutoImporter, - #MakeFunctionAsync, - #MakeCallAwait, - #OnlyKeywordArgumentDefaultNotSetCheckRule, - #InternalizeFunctions, - #RemoveDeadCode, - #RenameImportAndDownstream, - #AssertEncoder, - #PropagateAndDelete, - #FoldMyConstants, - #AtomicTryBlock, + ReplaceNexts, + ReplacePlaceholders, + PropagateConstants, + TypingAutoImporter, + MakeFunctionAsync, + MakeCallAwait, + OnlyKeywordArgumentDefaultNotSetCheckRule, + InternalizeFunctions, + RemoveDeadCode, + RenameImportAndDownstream, + AssertEncoder, + PropagateAndDelete, + FoldMyConstants, + AtomicTryBlock, WrapInMultilineFstring, ], ) From 2efaa1e94bc3004cfa6e8836657863a9b7c44a19 Mon Sep 17 00:00:00 2001 From: memento Date: Thu, 29 Dec 2022 15:40:16 -0600 Subject: [PATCH 8/8] Adding specific testcase from Issue #12 --- tests/test_ast.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/test_ast.py b/tests/test_ast.py index 4dc4e82..ea6c997 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -460,7 +460,7 @@ def test_apply_source_formatting_maintains_with_call_on_closing_parens(): assert lines.join() == expected_src -def test_apply_source_formatting_maintains_with_async(): +def test_apply_source_formatting_maintains_with_async_complex(): source = """def func(): if something: # Comments are retrieved @@ -504,6 +504,44 @@ def test_apply_source_formatting_maintains_with_async(): assert lines.join() == expected_src +def test_apply_source_formatting_maintains_with_async(): + source = """def func(): + if something: + # Comments are retrieved + with something: # comment2 + a = 1 # Non-standard indent + b = 2 # Non-standard indent becomes standard +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + async with something: + a = 1 + b = 2 # Non-standard indent becomes standard +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0] + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + def test_apply_source_formatting_maintains_with_fstring(): source = '''def f(): return """