Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion refactor/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,5 +255,57 @@ def retrieve_segment(self, node: ast.AST, segment: str) -> None:
self.fill()
self.write(segment)

class PreciseEmptyLinesUnparser(PreciseUnparser):
"""A more precise version of the default unparser,
with various improvements such as comment handling
for major statements and child node recovery."""

@contextmanager
def _collect_stmt_comments(self, node: ast.AST) -> Iterator[None]:
def _write_if_unseen_comment(
line_no: int,
line: str,
comment_begin: int,
) -> None:
if line_no in self._visited_comment_lines:
# We have already written this comment as the
# end of another node. No need to re-write it.
return

self.fill()
self.write(line[comment_begin:])
self._visited_comment_lines.add(line_no)

assert self.source is not None
lines = self.source.splitlines()
node_start, node_end = node.lineno - 1, cast(int, node.end_lineno)

# Collect comments in the reverse order, so we can properly
# identify the end of the current comment block.
preceding_comments = []
for offset, line in enumerate(reversed(lines[:node_start])):
comment_begin = line.find("#")
if (line and not line.isspace()) and (comment_begin == -1 or comment_begin != node.col_offset):
break

preceding_comments.append((node_start - offset, line, node.col_offset))

for comment_info in reversed(preceding_comments):
_write_if_unseen_comment(*comment_info)

yield

for offset, line in enumerate(lines[node_end:], 1):
comment_begin = line.find("#")
if comment_begin == -1 or comment_begin != node.col_offset:
break

_write_if_unseen_comment(
line_no=node_end + offset,
line=line,
comment_begin=node.col_offset,
)



UNPARSER_BACKENDS = {"fast": BaseUnparser, "precise": PreciseUnparser}
UNPARSER_BACKENDS = {"fast": BaseUnparser, "precise": PreciseUnparser, "precise_with_empty_lines": PreciseEmptyLinesUnparser}
112 changes: 111 additions & 1 deletion tests/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from refactor import common
from refactor.ast import BaseUnparser, PreciseUnparser, split_lines
from refactor.ast import BaseUnparser, PreciseUnparser, split_lines, PreciseEmptyLinesUnparser


def test_split_lines():
Expand Down Expand Up @@ -240,3 +240,113 @@ def foo():

base = PreciseUnparser(source=source)
assert base.unparse(tree) + "\n" == expected_src



def test_precise_empty_lines_unparser():
source = textwrap.dedent(
"""\
def func():
if something:
print(
call(.1),
maybe+something_else,
maybe / other,
thing . a
)
"""
)

expected_src = textwrap.dedent(
"""\
def func():
if something:
print(call(.1), maybe+something_else, maybe / other, thing . a, 3)
"""
)

tree = ast.parse(source)
tree.body[0].body[0].body[0].value.args.append(ast.Constant(3))

base = PreciseEmptyLinesUnparser(source=source)
assert base.unparse(tree) + "\n" == expected_src


def test_precise_empty_lines_unparser_indented_literals():
source = textwrap.dedent(
"""\
def func():
if something:
print(
"bleh"
"zoom"
)
"""
)

expected_src = textwrap.dedent(
"""\
def func():
if something:
print("bleh"
"zoom", 3)
"""
)

tree = ast.parse(source)
tree.body[0].body[0].body[0].value.args.append(ast.Constant(3))

base = PreciseEmptyLinesUnparser(source=source)
assert base.unparse(tree) + "\n" == expected_src

def test_precise_empty_lines_unparser_comments():
source = textwrap.dedent(
"""\
def foo():
# unindented comment
# indented but not connected comment

# a
# a1
print()
# a2
print()
# b

# b2
print(
c # e
)
# c
print(d)
# final comment
"""
)

expected_src = (
"""\
def foo():
# indented but not connected comment

# a
# a1
print()
# a2
print()
# b

# b2
print(
c # e
)
# c
"""
)

tree = ast.parse(source)

# # Remove the print(d)
tree.body[0].body.pop()

base = PreciseEmptyLinesUnparser(source=source)
assert base.unparse(tree) + "\n" == expected_src