From ffecfa4c4ad803d44b5fcf98e5d899170a3259e7 Mon Sep 17 00:00:00 2001 From: Swastika Pradhan Date: Sat, 30 Dec 2023 22:19:50 +0530 Subject: [PATCH] Update change.py --- refactor/change.py | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/refactor/change.py b/refactor/change.py index 5bc134b..402439a 100644 --- a/refactor/change.py +++ b/refactor/change.py @@ -2,6 +2,7 @@ import difflib import os +import argparse from dataclasses import dataclass from pathlib import Path @@ -40,12 +41,15 @@ def compute_diff(self) -> str: ) ) - def apply_diff(self) -> None: + def apply_diff(self, dry_run: bool = False) -> None: """Apply the transformed version to the bound file.""" - raw_source = self.refactored_source.encode(self.file_info.get_encoding()) - - with open(self.file, "wb") as stream: - stream.write(raw_source) + if dry_run: + diff = self.compute_diff() + print(diff) + else: + raw_source = self.refactored_source.encode(self.file_info.get_encoding()) + with open(self.file, "wb") as stream: + stream.write(raw_source) @property def file(self) -> Path: @@ -53,3 +57,27 @@ def file(self) -> Path: if self.file_info.path is None: raise ValueError("Change expects a valid file") return self.file_info.path + + +def refactor_file(file_path, dry_run=False): + # Perform the refactoring logic here and get the refactored_source + original_source = open(file_path).read() + # Assume refactored_source is obtained somehow in the refactoring process + + change = Change(file_info=_FileInfo(path=Path(file_path)), original_source=original_source, refactored_source=refactored_source) + change.apply_diff(dry_run=dry_run) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Refactor code.") + parser.add_argument("file_path", help="Path to the file to be refactored.") + parser.add_argument("--diff", action="store_true", help="Perform a dry-run and show the diff.") + parser.add_argument("--fail-on-change", action="store_true", help="Exit with 1 if there are any changes without refactoring.") + + args = parser.parse_args() + + refactor_file(args.file_path, dry_run=args.diff) + + if args.fail_on_change and args.diff: + print("Exiting with code 1 due to changes without refactoring.") + exit(1)