diff --git a/refactor/actions.py b/refactor/actions.py index 173a0e4..1812c8a 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -80,13 +80,17 @@ def _replace_input(self, node: ast.AST) -> _LazyActionMixin[K, T]: class _ReplaceCodeSegmentAction(BaseAction): def apply(self, context: Context, source: str) -> str: + # The decorators are removed in the 'lines' but present in the 'context` + # This lead to the 'replacement' containing the decorators and the returned + # 'lines' to duplicate them. Proposed workaround is to add the decorators in + # the 'view', in case the '_resynthesize()' adds/modifies them lines = split_lines(source, encoding=context.file_info.get_encoding()) ( lineno, col_offset, end_lineno, end_col_offset, - ) = self._get_segment_span(context) + ) = self._get_decorated_segment_span(context) view = slice(lineno - 1, end_lineno) source_lines = lines[view] @@ -104,6 +108,9 @@ def apply(self, context: Context, source: str) -> str: def _get_segment_span(self, context: Context) -> PositionType: raise NotImplementedError + def _get_decorated_segment_span(self, context: Context) -> PositionType: + raise NotImplementedError + def _resynthesize(self, context: Context) -> str: raise NotImplementedError @@ -123,6 +130,13 @@ class LazyReplace(_ReplaceCodeSegmentAction, _LazyActionMixin[ast.AST, ast.AST]) def _get_segment_span(self, context: Context) -> PositionType: return position_for(self.node) + def _get_decorated_segment_span(self, context: Context) -> PositionType: + lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) + # Add the decorators to the segment span to resolve an issue with def -> async def + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) + return lineno, col_offset, end_lineno, end_col_offset + def _resynthesize(self, context: Context) -> str: return context.unparse(self.build()) @@ -221,6 +235,9 @@ def apply(self, context: Context, source: str) -> str: replacement[-1] += lines._newline_type original_node_start = cast(int, self.node.lineno) + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + original_node_start, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) + for line in reversed(replacement): lines.insert(original_node_start - 1, line) @@ -290,6 +307,9 @@ class _Rename(Replace): def _get_segment_span(self, context: Context) -> PositionType: return self.identifier_span + def _get_decorated_segment_span(self, context: Context) -> PositionType: + return self.identifier_span + def _resynthesize(self, context: Context) -> str: return self.target.name @@ -322,6 +342,13 @@ def is_critical_node(self, context: Context) -> bool: def _get_segment_span(self, context: Context) -> PositionType: return position_for(self.node) + def _get_decorated_segment_span(self, context: Context) -> PositionType: + lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) + # Add the decorators to the segment span to resolve an issue with def -> async def + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) + return lineno, col_offset, end_lineno, end_col_offset + def _resynthesize(self, context: Context) -> str: if self.is_critical_node(context): raise InvalidActionError( diff --git a/tests/test_actions.py b/tests/test_actions.py index a8040be..fb977df 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -10,6 +10,7 @@ from refactor import Session, common from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore +from refactor.common import clone from refactor.context import Context from refactor.core import Rule @@ -48,60 +49,92 @@ def foo(): INVALID_ERASES_TREE = ast.parse(INVALID_ERASES) +class TestInsertBeforeDecoratedFunction(Rule): + INPUT_SOURCE = """ + @decorate + def decorated(): + test_this()""" + + EXPECTED_SOURCE = """ + await async_test() + @decorate + async def decorated(): + test_this()""" + + def match(self, node: ast.AST) -> Iterator[InsertBefore]: + assert isinstance(node, ast.FunctionDef) + + await_st = ast.parse("await async_test()") + yield InsertBefore(node, cast(ast.stmt, await_st)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) + + +class TestInsertBeforeMultipleDecorators(Rule): + INPUT_SOURCE = """ + @decorate0 + @decorate1 + @decorate2 + def decorated(): + test_this()""" + + EXPECTED_SOURCE = """ + await async_test() + @decorate0 + @decorate1 + @decorate2 + async def decorated(): + test_this()""" + + def match(self, node: ast.AST) -> Iterator[InsertBefore]: + assert isinstance(node, ast.FunctionDef) + + await_st = ast.parse("await async_test()") + yield InsertBefore(node, cast(ast.stmt, await_st)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) + + class TestInsertAfterBottom(Rule): INPUT_SOURCE = """ - try: - base_tree = get_tree(base_file, module_name) - first_tree = get_tree(first_tree, module_name) - second_tree = get_tree(second_tree, module_name) - third_tree = get_tree(third_tree, module_name) - except (SyntaxError, FileNotFoundError): - continue""" + def undecorated(): + test_this()""" EXPECTED_SOURCE = """ - try: - base_tree = get_tree(base_file, module_name) - except (SyntaxError, FileNotFoundError): - continue + async def undecorated(): + test_this() await async_test()""" def match(self, node: ast.AST) -> Iterator[InsertAfter]: - assert isinstance(node, ast.Try) - assert len(node.body) >= 2 + assert isinstance(node, ast.FunctionDef) await_st = ast.parse("await async_test()") yield InsertAfter(node, cast(ast.stmt, await_st)) - new_try = common.clone(node) - new_try.body = [node.body[0]] - yield Replace(node, cast(ast.AST, new_try)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) class TestInsertBeforeTop(Rule): INPUT_SOURCE = """ - try: - base_tree = get_tree(base_file, module_name) - first_tree = get_tree(first_tree, module_name) - second_tree = get_tree(second_tree, module_name) - third_tree = get_tree(third_tree, module_name) - except (SyntaxError, FileNotFoundError): - continue""" + def undecorated(): + test_this()""" EXPECTED_SOURCE = """ await async_test() - try: - base_tree = get_tree(base_file, module_name) - except (SyntaxError, FileNotFoundError): - continue""" + async def undecorated(): + test_this()""" def match(self, node: ast.AST) -> Iterator[InsertBefore]: - assert isinstance(node, ast.Try) - assert len(node.body) >= 2 + assert isinstance(node, ast.FunctionDef) await_st = ast.parse("await async_test()") yield InsertBefore(node, cast(ast.stmt, await_st)) - new_try = common.clone(node) - new_try.body = [node.body[0]] - yield Replace(node, cast(ast.AST, new_try)) + new_node = clone(node) + new_node.__class__ = ast.AsyncFunctionDef + yield Replace(node, new_node) class TestInsertAfter(Rule): @@ -488,6 +521,8 @@ def test_erase_invalid(invalid_node): @pytest.mark.parametrize( "rule", [ + TestInsertBeforeDecoratedFunction, + TestInsertBeforeMultipleDecorators, TestInsertAfterBottom, TestInsertBeforeTop, TestInsertAfter, diff --git a/tests/test_common.py b/tests/test_common.py index 5ff4b9d..49b5085 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -137,6 +137,24 @@ def func(): assert position_for(right_node) == (3, 23, 3, 25) +def test_get_positions_with_decorator(): + source = textwrap.dedent( + """\ + @deco0 + @deco1(arg0, + arg1) + def func(): + if a > 5: + return 5 + 3 + 25 + elif b > 10: + return 1 + 3 + 5 + 7 + """ + ) + tree = ast.parse(source) + right_node = tree.body[0].body[0].body[0].value.right + assert position_for(right_node) == (6, 23, 6, 25) + + def test_singleton(): from dataclasses import dataclass diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index dfd4984..8b299be 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -371,6 +371,48 @@ def match(self, node): return AsyncifierAction(node) +class MakeFunctionAsyncWithDecorators(Rule): + INPUT_SOURCE = """ + @deco0 + @deco1(arg0, + arg1) + def something(): + a += .1 + '''you know + this is custom + literal + ''' + print(we, + preserve, + everything + ) + return ( + right + "?") + """ + + EXPECTED_SOURCE = """ + @deco0 + @deco1(arg0, + arg1) + async def something(): + a += .1 + '''you know + this is custom + literal + ''' + print(we, + preserve, + everything + ) + return ( + right + "?") + """ + + def match(self, node): + assert isinstance(node, ast.FunctionDef) + return AsyncifierAction(node) + + class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): context_providers = (context.Scope,)