diff --git a/refactor/actions.py b/refactor/actions.py index 13e8d76..4905a7a 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -89,11 +89,11 @@ def apply(self, context: Context, source: str) -> str: view = slice(lineno - 1, end_lineno) source_lines = lines[view] - indentation, start_prefix = find_indent(source_lines[0][:col_offset]) - end_suffix = source_lines[-1][end_col_offset:] replacement = split_lines(self._resynthesize(context)) - replacement.apply_indentation( - indentation, start_prefix=start_prefix, end_suffix=end_suffix + # Applies the block indentation only if the replacement lines are different from source + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), ) lines[view] = replacement @@ -168,12 +168,12 @@ class LazyInsertAfter(_LazyActionMixin[ast.stmt, ast.stmt]): def apply(self, context: Context, source: str) -> str: lines = split_lines(source, encoding=context.file_info.get_encoding()) - indentation, start_prefix = find_indent( - lines[self.node.lineno - 1][: self.node.col_offset] - ) replacement = split_lines(context.unparse(self.build())) - replacement.apply_indentation(indentation, start_prefix=start_prefix) + replacement.apply_source_formatting( + source_lines=lines, + markers=(self.node.lineno - 1, self.node.col_offset, None), + ) original_node_end = cast(int, self.node.end_lineno) - 1 if lines[original_node_end].endswith(lines._newline_type): diff --git a/refactor/ast.py b/refactor/ast.py index 3d23262..c2a783e 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -10,9 +10,10 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import cached_property -from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast +from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple from refactor import common +from refactor.common import find_indent DEFAULT_ENCODING = "utf-8" @@ -32,19 +33,33 @@ def join(self) -> str: """Return the combined source code.""" return "".join(map(str, self.lines)) - def apply_indentation( + def apply_source_formatting( self, - indentation: StringType, + source_lines: Lines, *, - start_prefix: AnyStringType = "", - end_suffix: AnyStringType = "", + markers: Tuple[int, int, int | None] = None, ) -> None: - """Apply the given indentation, optionally with start and end prefixes - to the bound source lines.""" + """Apply the indentation from source_lines when the first several characters match + :param source_lines: Original lines in source code + :param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None) + """ + + indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) + end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] + + original_line: str | None for index, line in enumerate(self.data): + if index < len(source_lines): + original_line = source_lines[index] + else: + original_line = None + if index == 0: self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore + # The updated line can have an extra wrapping in brackets + elif original_line is not None and original_line.startswith(line[:-1]): + self.data[index] = line # type: ignore else: self.data[index] = indentation + line # type: ignore @@ -77,7 +92,7 @@ def __getitem__(self, index: SupportsIndex | slice) -> SourceSegment: # re-implements the direct indexing as slicing (e.g. a[1] is a[1:2], with # error handling). direct_index = operator.index(index) - view = raw_line[direct_index : direct_index + 1].decode( + view = raw_line[direct_index: direct_index + 1].decode( encoding=self.encoding ) if not view: diff --git a/tests/test_ast.py b/tests/test_ast.py index 1e5dc1d..ea6c997 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -6,8 +6,9 @@ import pytest -from refactor import common +from refactor import common, Context from refactor.ast import BaseUnparser, PreciseUnparser, split_lines +from refactor.common import position_for, clone def test_split_lines(): @@ -169,6 +170,7 @@ def test_precise_unparser_indented_literals(): """\ def func(): if something: + # On change, comments are removed print( "bleh" "zoom" @@ -240,3 +242,374 @@ def foo(): base = PreciseUnparser(source=source) assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_no_changes(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + tree = ast.parse(source) + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_del(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + tree = ast.parse(source) + del tree.body[0].body[0].body[0].value.args[2] + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_apply_source_formatting_maintains_with_await_0(): + source = """def func(): + if something: + # Comments are retrieved + print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_await_1(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + call_instead(print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + )) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call_on_closing_parens(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) # This is mis-aligned +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + call_instead(print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + )) # This is mis-aligned +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_async_complex(): + source = """def func(): + if something: + # Comments are retrieved + with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + async with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0] + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_async(): + source = """def func(): + if something: + # Comments are retrieved + with something: # comment2 + a = 1 # Non-standard indent + b = 2 # Non-standard indent becomes standard +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + async with something: + a = 1 + b = 2 # Non-standard indent becomes standard +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0] + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_fstring(): + source = '''def f(): + return """ +a +""" +''' + + expected_src = '''def f(): + return F(""" +a +""") +''' + + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="F"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_does_not_with_change(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + del node_to_replace.args[2] + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index b46c14e..33c1e5e 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -296,6 +296,37 @@ def match(self, node): return AsyncifierAction(node) +class AwaitifierAction(LazyReplace): + def build(self): + if isinstance(self.node, ast.Expr): + self.node.value = ast.Await(self.node.value) + return self.node + if isinstance(self.node, ast.Call): + new_node = ast.Await(self.node) + return new_node + + +class MakeCallAwait(Rule): + INPUT_SOURCE = """ + def somefunc(): + call( + arg0, + arg1) # Intentional mis-alignment + """ + + EXPECTED_SOURCE = """ + def somefunc(): + await call( + arg0, + arg1) # Intentional mis-alignment + """ + + def match(self, node): + assert isinstance(node, ast.Expr) + assert isinstance(node.value, ast.Call) + return AwaitifierAction(node) + + class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): context_providers = (context.Scope,) @@ -941,6 +972,31 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: yield InsertAfter(node, remaining_try) +class WrapInMultilineFstring(Rule): + INPUT_SOURCE = ''' +def f(): + return """ +a +""" +''' + EXPECTED_SOURCE = ''' +def f(): + return F(""" +a +""") +''' + + def match(self, node): + assert isinstance(node, ast.Constant) + + # Prevent wrapping F-strings that are already wrapped in F() + # Otherwise you get infinite F(F(F(F(...)))) + parent = self.context.ancestry.get_parent(node) + assert not (isinstance(parent, ast.Call) and isinstance(parent.func, ast.Name) and parent.func.id == 'F') + + return Replace(node, ast.Call(func=ast.Name(id="F"), args=[node], keywords=[])) + + @pytest.mark.parametrize( "rule", [ @@ -949,6 +1005,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: PropagateConstants, TypingAutoImporter, MakeFunctionAsync, + MakeCallAwait, OnlyKeywordArgumentDefaultNotSetCheckRule, InternalizeFunctions, RemoveDeadCode, @@ -957,6 +1014,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: PropagateAndDelete, FoldMyConstants, AtomicTryBlock, + WrapInMultilineFstring, ], ) def test_complete_rules(rule, tmp_path):