Skip to content

Fix/insert before decorated #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion refactor/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,17 @@ def _replace_input(self, node: ast.AST) -> _LazyActionMixin[K, T]:

class _ReplaceCodeSegmentAction(BaseAction):
def apply(self, context: Context, source: str) -> str:
# The decorators are removed in the 'lines' but present in the 'context`
# This lead to the 'replacement' containing the decorators and the returned
# 'lines' to duplicate them. Proposed workaround is to add the decorators in
# the 'view', in case the '_resynthesize()' adds/modifies them
lines = split_lines(source, encoding=context.file_info.get_encoding())
(
lineno,
col_offset,
end_lineno,
end_col_offset,
) = self._get_segment_span(context)
) = self._get_decorated_segment_span(context)

view = slice(lineno - 1, end_lineno)
source_lines = lines[view]
Expand All @@ -104,6 +108,9 @@ def apply(self, context: Context, source: str) -> str:
def _get_segment_span(self, context: Context) -> PositionType:
raise NotImplementedError

def _get_decorated_segment_span(self, context: Context) -> PositionType:
raise NotImplementedError

def _resynthesize(self, context: Context) -> str:
raise NotImplementedError

Expand All @@ -123,6 +130,13 @@ class LazyReplace(_ReplaceCodeSegmentAction, _LazyActionMixin[ast.AST, ast.AST])
def _get_segment_span(self, context: Context) -> PositionType:
return position_for(self.node)

def _get_decorated_segment_span(self, context: Context) -> PositionType:
lineno, col_offset, end_lineno, end_col_offset = position_for(self.node)
# Add the decorators to the segment span to resolve an issue with def -> async def
if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0:
lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0])
return lineno, col_offset, end_lineno, end_col_offset

def _resynthesize(self, context: Context) -> str:
return context.unparse(self.build())

Expand Down Expand Up @@ -221,6 +235,9 @@ def apply(self, context: Context, source: str) -> str:
replacement[-1] += lines._newline_type

original_node_start = cast(int, self.node.lineno)
if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0:
original_node_start, _, _, _ = position_for(getattr(self.node, "decorator_list")[0])

for line in reversed(replacement):
lines.insert(original_node_start - 1, line)

Expand Down Expand Up @@ -290,6 +307,9 @@ class _Rename(Replace):
def _get_segment_span(self, context: Context) -> PositionType:
return self.identifier_span

def _get_decorated_segment_span(self, context: Context) -> PositionType:
return self.identifier_span

def _resynthesize(self, context: Context) -> str:
return self.target.name

Expand Down Expand Up @@ -322,6 +342,13 @@ def is_critical_node(self, context: Context) -> bool:
def _get_segment_span(self, context: Context) -> PositionType:
return position_for(self.node)

def _get_decorated_segment_span(self, context: Context) -> PositionType:
lineno, col_offset, end_lineno, end_col_offset = position_for(self.node)
# Add the decorators to the segment span to resolve an issue with def -> async def
if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0:
lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0])
return lineno, col_offset, end_lineno, end_col_offset

def _resynthesize(self, context: Context) -> str:
if self.is_critical_node(context):
raise InvalidActionError(
Expand Down
99 changes: 67 additions & 32 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from refactor import Session, common
from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore
from refactor.common import clone
from refactor.context import Context
from refactor.core import Rule

Expand Down Expand Up @@ -48,60 +49,92 @@ def foo():
INVALID_ERASES_TREE = ast.parse(INVALID_ERASES)


class TestInsertBeforeDecoratedFunction(Rule):
INPUT_SOURCE = """
@decorate
def decorated():
test_this()"""

EXPECTED_SOURCE = """
await async_test()
@decorate
async def decorated():
test_this()"""

def match(self, node: ast.AST) -> Iterator[InsertBefore]:
assert isinstance(node, ast.FunctionDef)

await_st = ast.parse("await async_test()")
yield InsertBefore(node, cast(ast.stmt, await_st))
new_node = clone(node)
new_node.__class__ = ast.AsyncFunctionDef
yield Replace(node, new_node)


class TestInsertBeforeMultipleDecorators(Rule):
INPUT_SOURCE = """
@decorate0
@decorate1
@decorate2
def decorated():
test_this()"""

EXPECTED_SOURCE = """
await async_test()
@decorate0
@decorate1
@decorate2
async def decorated():
test_this()"""

def match(self, node: ast.AST) -> Iterator[InsertBefore]:
assert isinstance(node, ast.FunctionDef)

await_st = ast.parse("await async_test()")
yield InsertBefore(node, cast(ast.stmt, await_st))
new_node = clone(node)
new_node.__class__ = ast.AsyncFunctionDef
yield Replace(node, new_node)


class TestInsertAfterBottom(Rule):
INPUT_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
first_tree = get_tree(first_tree, module_name)
second_tree = get_tree(second_tree, module_name)
third_tree = get_tree(third_tree, module_name)
except (SyntaxError, FileNotFoundError):
continue"""
def undecorated():
test_this()"""

EXPECTED_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
except (SyntaxError, FileNotFoundError):
continue
async def undecorated():
test_this()
await async_test()"""

def match(self, node: ast.AST) -> Iterator[InsertAfter]:
assert isinstance(node, ast.Try)
assert len(node.body) >= 2
assert isinstance(node, ast.FunctionDef)

await_st = ast.parse("await async_test()")
yield InsertAfter(node, cast(ast.stmt, await_st))
new_try = common.clone(node)
new_try.body = [node.body[0]]
yield Replace(node, cast(ast.AST, new_try))
new_node = clone(node)
new_node.__class__ = ast.AsyncFunctionDef
yield Replace(node, new_node)


class TestInsertBeforeTop(Rule):
INPUT_SOURCE = """
try:
base_tree = get_tree(base_file, module_name)
first_tree = get_tree(first_tree, module_name)
second_tree = get_tree(second_tree, module_name)
third_tree = get_tree(third_tree, module_name)
except (SyntaxError, FileNotFoundError):
continue"""
def undecorated():
test_this()"""

EXPECTED_SOURCE = """
await async_test()
try:
base_tree = get_tree(base_file, module_name)
except (SyntaxError, FileNotFoundError):
continue"""
async def undecorated():
test_this()"""

def match(self, node: ast.AST) -> Iterator[InsertBefore]:
assert isinstance(node, ast.Try)
assert len(node.body) >= 2
assert isinstance(node, ast.FunctionDef)

await_st = ast.parse("await async_test()")
yield InsertBefore(node, cast(ast.stmt, await_st))
new_try = common.clone(node)
new_try.body = [node.body[0]]
yield Replace(node, cast(ast.AST, new_try))
new_node = clone(node)
new_node.__class__ = ast.AsyncFunctionDef
yield Replace(node, new_node)


class TestInsertAfter(Rule):
Expand Down Expand Up @@ -488,6 +521,8 @@ def test_erase_invalid(invalid_node):
@pytest.mark.parametrize(
"rule",
[
TestInsertBeforeDecoratedFunction,
TestInsertBeforeMultipleDecorators,
TestInsertAfterBottom,
TestInsertBeforeTop,
TestInsertAfter,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,24 @@ def func():
assert position_for(right_node) == (3, 23, 3, 25)


def test_get_positions_with_decorator():
source = textwrap.dedent(
"""\
@deco0
@deco1(arg0,
arg1)
def func():
if a > 5:
return 5 + 3 + 25
elif b > 10:
return 1 + 3 + 5 + 7
"""
)
tree = ast.parse(source)
right_node = tree.body[0].body[0].body[0].value.right
assert position_for(right_node) == (6, 23, 6, 25)


def test_singleton():
from dataclasses import dataclass

Expand Down
42 changes: 42 additions & 0 deletions tests/test_complete_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,48 @@ def match(self, node):
return AsyncifierAction(node)


class MakeFunctionAsyncWithDecorators(Rule):
INPUT_SOURCE = """
@deco0
@deco1(arg0,
arg1)
def something():
a += .1
'''you know
this is custom
literal
'''
print(we,
preserve,
everything
)
return (
right + "?")
"""

EXPECTED_SOURCE = """
@deco0
@deco1(arg0,
arg1)
async def something():
a += .1
'''you know
this is custom
literal
'''
print(we,
preserve,
everything
)
return (
right + "?")
"""

def match(self, node):
assert isinstance(node, ast.FunctionDef)
return AsyncifierAction(node)


class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule):
context_providers = (context.Scope,)

Expand Down