16 changes: 8 additions & 8 deletions refactor/actions.py
@@ -89,11 +89,11 @@ def apply(self, context: Context, source: str) -> str:
view = slice(lineno - 1, end_lineno)
source_lines = lines[view]

indentation, start_prefix = find_indent(source_lines[0][:col_offset])
end_suffix = source_lines[-1][end_col_offset:]
replacement = split_lines(self._resynthesize(context))
replacement.apply_indentation(
indentation, start_prefix=start_prefix, end_suffix=end_suffix
# Applies the block indentation only if the replacement lines differ from the source
replacement.apply_source_formatting(
source_lines=source_lines,
markers=(0, col_offset, end_col_offset),
)

lines[view] = replacement
@@ -168,12 +168,12 @@ class LazyInsertAfter(_LazyActionMixin[ast.stmt, ast.stmt]):

def apply(self, context: Context, source: str) -> str:
lines = split_lines(source, encoding=context.file_info.get_encoding())
indentation, start_prefix = find_indent(
lines[self.node.lineno - 1][: self.node.col_offset]
)

replacement = split_lines(context.unparse(self.build()))
replacement.apply_indentation(indentation, start_prefix=start_prefix)
replacement.apply_source_formatting(
source_lines=lines,
markers=(self.node.lineno - 1, self.node.col_offset, None),
)

original_node_end = cast(int, self.node.end_lineno) - 1
if lines[original_node_end].endswith(lines._newline_type):
28 changes: 21 additions & 7 deletions refactor/ast.py
@@ -10,7 +10,10 @@
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import cached_property
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast
from os.path import commonprefix
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, List, Tuple

from refactor.common import find_indent

from refactor import common

@@ -32,17 +35,28 @@ def join(self) -> str:
"""Return the combined source code."""
return "".join(map(str, self.lines))

def apply_indentation(
def apply_source_formatting(
self,
indentation: StringType,
source_lines: Lines,
*,
start_prefix: AnyStringType = "",
end_suffix: AnyStringType = "",
markers: Tuple[int, int, int | None],
) -> None:
"""Apply the given indentation, optionally with start and end prefixes
to the bound source lines."""
"""Apply the indentation from source_lines when the first several characters match

:param source_lines: Original lines in source code
:param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None)
"""

def not_original(i: int) -> bool:
common_chars: str = commonprefix([str(self.data[i]), str(source_lines.data[i].data)])
is_multiline_string: int = str(self.data[i]).find(common_chars) == 0 and common_chars in ["'''", '"""']
return not (i < len(source_lines.data) and (str(self.data[i]) == common_chars or is_multiline_string))
Owner:

So if we reduce it to the bare minimum (let's drop the multiline strings for now, just to make this patch simpler), what this does is basically check whether the original line starts with the replaced line, and if that is the case we skip indenting?

A better way to express it might be this:

        for index, line in enumerate(self.data):
            if index < len(source_lines):
                original_line = source_lines[index]
            else:
                original_line = None

            if index == 0:
                self.data[index] = indentation + str(start_prefix) + str(line)  # type: ignore
            elif original_line is not None and original_line.startswith(line):
                self.data[index] = line # type: ignore
            else:
                self.data[index] = indentation + line  # type: ignore

But I am still unsure how this should work in different scenarios; I would love to see a few more examples (or, actually, tests) where the indentation is preserved for the first argument (when replacing the whole call), the middle arguments, and the last argument (maybe introduce a new argument or remove one from the middle).

Contributor Author:

It only works when the Call is unchanged; removing or adding arguments does not maintain the original form but collapses the Call into one line. The case I was working with was wrapping the Call in Await, and it works well in that case.

There is something I don't understand, though: I added a few tests (bear with my padawan level) and the issue #12 case passes, but in the complete_rules it does not. It would be great if you took a look at those tests; I must be doing something (or more than one thing) incorrectly.

Contributor Author:

Ok, the issue was that in test_ast.py I use the full source_lines (with indent=0, so it is not added to the failing line), but in test_complete_rules.py only the changed source lines are passed (because of the view), so the indentation is there and the test fails.
Now, the 'multi-line' issue is more of a "within parens" issue, so would it be acceptable to just have this:
elif original_line is not None and (original_line.startswith(line) or line[-1] in (")}]")):
This way the test passes.
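
For illustration only (not part of the patch), here is a minimal standalone sketch of how the proposed condition would classify the replacement lines of the MakeCallAwait test added below; keeps_original_indentation is a hypothetical name and only approximates the loop inside apply_source_formatting:

    # Hypothetical sketch of the proposed skip-indentation heuristic.
    def keeps_original_indentation(line: str, original_line: str | None) -> bool:
        """Return True when the replacement line should be left un-indented."""
        if original_line is None:
            return False
        # Skip re-indenting when the original line already starts with the
        # replacement line, or the replacement line ends inside brackets.
        return original_line.startswith(line) or line.endswith((")", "}", "]"))

    replacement = ["await call(", "    arg0,", "    arg1)"]
    original = ["call(", "    arg0,", "    arg1)"]
    for new, old in zip(replacement, original):
        print(repr(new), keeps_original_indentation(new, old))
    # 'await call(' False  -> first line gets the indentation/prefix treatment
    # '    arg0,'   True   -> present verbatim in the original source
    # '    arg1)'   True   -> ends with a closing paren, kept as-is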


indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]])
end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:]

for index, line in enumerate(self.data):
indentation = indentation if not_original(index) else ""
if index == 0:
self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore
else:
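As a side note (not part of the diff), a hedged illustration of how the markers tuple above is consumed; split_leading_indent is a hypothetical stand-in for refactor.common.find_indent, whose exact behaviour is assumed here:

    # Hypothetical stand-in for refactor.common.find_indent: split a prefix
    # into its leading whitespace and the non-whitespace text that follows.
    def split_leading_indent(prefix: str) -> tuple[str, str]:
        stripped = prefix.lstrip()
        return prefix[: len(prefix) - len(stripped)], stripped

    source_line = "    result = call(arg0, arg1)  # trailing comment"
    col_offset, end_col_offset = 13, 29  # span of "call(arg0, arg1)"

    # markers = (start line index, col_offset, end_col_offset | None)
    indentation, start_prefix = split_leading_indent(source_line[:col_offset])
    end_suffix = source_line[end_col_offset:]

    print(repr(indentation))   # '    '
    print(repr(start_prefix))  # 'result = '
    print(repr(end_suffix))    # '  # trailing comment'

The first replacement line would then presumably be rebuilt as indentation + start_prefix + line, with end_suffix appended after the last line, as in the previous apply_indentation call in actions.py.
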
74 changes: 66 additions & 8 deletions tests/test_complete_rules.py
@@ -296,6 +296,37 @@ def match(self, node):
return AsyncifierAction(node)


class AwaitifierAction(LazyReplace):
def build(self):
if isinstance(self.node, ast.Expr):
self.node.value = ast.Await(self.node.value)
return self.node
if isinstance(self.node, ast.Call):
new_node = ast.Await(self.node)
return new_node


class MakeCallAwait(Rule):
INPUT_SOURCE = """
def somefunc():
call(
arg0,
arg1) # Intentional mis-alignment
"""

EXPECTED_SOURCE = """
def somefunc():
await call(
arg0,
arg1) # Intentional mis-alignment
"""

def match(self, node):
assert isinstance(node, ast.Expr)
assert isinstance(node.value, ast.Call)
return AwaitifierAction(node)


class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule):
context_providers = (context.Scope,)

@@ -322,7 +353,7 @@ def match(self, node: ast.AST) -> BaseAction | None:
assert any(kw_default is None for kw_default in node.args.kw_defaults)

if isinstance(node, ast.Lambda) and not (
isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load)
isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load)
):
scope = self.context["scope"].resolve(node.body)
scope.definitions.get(node.body.id, [])
@@ -331,8 +362,8 @@ def match(self, node: ast.AST) -> BaseAction | None:
for stmt in node.body:
for identifier in ast.walk(stmt):
if not (
isinstance(identifier, ast.Name)
and isinstance(identifier.ctx, ast.Load)
isinstance(identifier, ast.Name)
and isinstance(identifier.ctx, ast.Load)
):
continue

@@ -598,13 +629,13 @@ class DownstreamAnalyzer(Representative):
context_providers = (context.Scope,)

def iter_dependents(
self, name: str, source: ast.Import | ast.ImportFrom
self, name: str, source: ast.Import | ast.ImportFrom
) -> Iterator[ast.Name]:
for node in ast.walk(self.context.tree):
if (
isinstance(node, ast.Name)
and isinstance(node.ctx, ast.Load)
and node.id == name
isinstance(node, ast.Name)
and isinstance(node.ctx, ast.Load)
and node.id == name
):
node_scope = self.context.scope.resolve(node)
definitions = node_scope.get_definitions(name)
@@ -699,7 +730,7 @@ def match(self, node: ast.AST) -> Iterator[Replace]:

[alias] = aliases
for dependent in self.context.downstream_analyzer.iter_dependents(
alias.asname or alias.name, node
alias.asname or alias.name, node
):
yield Replace(dependent, ast.Name("b", ast.Load()))

@@ -941,6 +972,31 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]:
yield InsertAfter(node, remaining_try)


class WrapInMultilineFstring(Rule):
INPUT_SOURCE = '''
def f():
return """
a
"""
'''
EXPECTED_SOURCE = '''
def f():
return F("""
a
""")
'''

def match(self, node):
assert isinstance(node, ast.Constant)

# Prevent wrapping F-strings that are already wrapped in F()
# Otherwise you get infinite F(F(F(F(...))))
parent = self.context.ancestry.get_parent(node)
assert not (isinstance(parent, ast.Call) and isinstance(parent.func, ast.Name) and parent.func.id == 'F')

return Replace(node, ast.Call(func=ast.Name(id="F"), args=[node], keywords=[]))


@pytest.mark.parametrize(
"rule",
[
@@ -949,6 +1005,7 @@
PropagateConstants,
TypingAutoImporter,
MakeFunctionAsync,
MakeCallAwait,
OnlyKeywordArgumentDefaultNotSetCheckRule,
InternalizeFunctions,
RemoveDeadCode,
@@ -957,6 +1014,7 @@
PropagateAndDelete,
FoldMyConstants,
AtomicTryBlock,
WrapInMultilineFstring,
],
)
def test_complete_rules(rule, tmp_path):