From dc0c6de0f56d573dd0ef2c4f7a01c6a6709ad49f Mon Sep 17 00:00:00 2001 From: Ananya Nayak Date: Sat, 3 Jun 2023 19:11:36 +0530 Subject: [PATCH 1/6] chore: improved list-feeds command --- gtfs/__main__.py | 99 ++++++++++++++++++---------- gtfs/feed_sources/__init__.py | 2 +- gtfs/utils/check_status.py | 0 gtfs/utils/constants.py | 7 -- gtfs/utils/extend_effective_dates.py | 0 pyproject.toml | 1 + tests/test_cli.py | 15 +++-- 7 files changed, 77 insertions(+), 47 deletions(-) mode change 100755 => 100644 gtfs/utils/check_status.py mode change 100755 => 100644 gtfs/utils/extend_effective_dates.py diff --git a/gtfs/__main__.py b/gtfs/__main__.py index 53f9069..5a8de4f 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -1,14 +1,15 @@ #!/usr/bin/env python """Command line interface for fetching GTFS.""" import logging +from typing import Union import typer from prettytable.colortable import ColorTable, Themes from typing_extensions import Annotated from .feed_source import FeedSource -from .feed_sources import __all__ as feed_sources -from .utils.constants import Predicate, console, spinner, success +from .feed_sources import feed_sources +from .utils.constants import Predicate, spinner from .utils.geom import Bbox, bbox_contains_bbox, bbox_intersects_bbox logging.basicConfig() @@ -16,7 +17,9 @@ app = typer.Typer() -def check_bbox(bbox: str) -> Bbox: +def check_bbox(bbox: str) -> Union[Bbox, None]: + if bbox is None: + return try: min_x, min_y, max_x, max_y = [float(coord) for coord in bbox.split(",")] except ValueError as e: @@ -39,50 +42,76 @@ def check_bbox(bbox: str) -> Bbox: @app.command() def list_feeds( bbox: Annotated[ - str, + Union[str, None], typer.Option( "--bbox", "-b", help="pass value as a string separated by commas like this: min_x,min_y,max_x,max_y", callback=check_bbox, ), - ], + ] = None, predicate: Annotated[ - Predicate, + Union[Predicate, None], typer.Option( "--predicate", - "-p", + "-pd", help="the gtfs feed should intersect or should be contained inside the user's bbox", ), - ] = Predicate.intersects, + ] = None, + pretty: Annotated[ + bool, + typer.Option( + "--pretty", + "-pt", + help="display feeds inside a pretty table", + ), + ] = False, ) -> None: """Filter feeds spatially based on bounding box.""" - spinner("Fetching feeds...", 1) - ptable = ColorTable(["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN) - for src in feed_sources: - feed_bbox: Bbox = src.bbox - if predicate == "contains": - if not bbox_contains_bbox(feed_bbox, bbox): - continue - elif predicate == "intersects": - if (not bbox_intersects_bbox(feed_bbox, bbox)) and ( - not bbox_intersects_bbox(bbox, feed_bbox) - ): - continue - - row = [ - src.__name__, - src.url, - [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], - ] - ptable.add_row(row) - - print( - "\n" + f"Feeds based on bbox input {bbox} and " - f"for predicate={predicate.value} are as follows:" - ) - print("\n" + ptable.get_string()) - console.print("All done!", style=success) + if bbox is None and predicate is not None: + raise typer.BadParameter( + f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!" + ) + elif bbox is not None and predicate is None: + raise typer.BadParameter( + f"Please pass a predicate if you want to filter feeds spatially based on bbox = {bbox}!" + ) + else: + spinner("Fetching feeds...", 1) + typer.secho("Filtered feeds are:", fg=typer.colors.BLUE) + if pretty is True: + output = ColorTable( + ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 + ) + else: + output = [] + + for src in feed_sources: + feed_bbox: Bbox = src.bbox + if bbox is not None and predicate == "contains": + if not bbox_contains_bbox(feed_bbox, bbox): + continue + elif bbox is not None and predicate == "intersects": + if (not bbox_intersects_bbox(feed_bbox, bbox)) and ( + not bbox_intersects_bbox(bbox, feed_bbox) + ): + continue + + if pretty is True: + row = [ + src.__name__, + src.url, + [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], + ] + output.add_row(row) + else: + output.append(src.url) + + if pretty is True: + print("\n" + output.get_string()) + else: + print("\n".join(output)) + typer.secho("All done!", fg=typer.colors.GREEN) @app.command() @@ -95,7 +124,7 @@ def fetch_feeds(sources=None): # make a copy of the list of all modules in feed_sources; # default to use all of them if not sources: - sources = list(feed_sources.__all__) + sources = list(feed_sources) LOG.info("Going to fetch feeds from sources: %s", sources) for src in sources: diff --git a/gtfs/feed_sources/__init__.py b/gtfs/feed_sources/__init__.py index d984eb3..fb84bf1 100644 --- a/gtfs/feed_sources/__init__.py +++ b/gtfs/feed_sources/__init__.py @@ -3,4 +3,4 @@ from .AlbanyNy import AlbanyNy from .Berlin import Berlin -__all__ = [Berlin, AlbanyNy] +feed_sources = [Berlin, AlbanyNy] diff --git a/gtfs/utils/check_status.py b/gtfs/utils/check_status.py old mode 100755 new mode 100644 diff --git a/gtfs/utils/constants.py b/gtfs/utils/constants.py index 9ae0fdb..e4e62e0 100644 --- a/gtfs/utils/constants.py +++ b/gtfs/utils/constants.py @@ -1,9 +1,7 @@ import time from enum import Enum -from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn -from rich.style import Style class Predicate(str, Enum): @@ -11,11 +9,6 @@ class Predicate(str, Enum): contains = "contains" -error = Style(color="red", bold=True) -success = Style(color="green", bold=True) -console = Console() - - def spinner(text: str, timer: int) -> None: with Progress( SpinnerColumn(), diff --git a/gtfs/utils/extend_effective_dates.py b/gtfs/utils/extend_effective_dates.py old mode 100755 new mode 100644 diff --git a/pyproject.toml b/pyproject.toml index 1007582..b2a993f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ skip = [ ] [tool.mypy] +ignore_missing_imports = true # allow imports of non-typed libraries strict_optional = true # no implicit Optional[Any] for None warn_return_any = true # if return type is any for a function warn_redundant_casts = true # if a cast to a type is unnecessary diff --git a/tests/test_cli.py b/tests/test_cli.py index cf67350..885793b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -31,13 +31,20 @@ def test_bad_args_3(self, runner): assert "Area cannot be zero!" in result.stdout def test_intersects_predicate(self, runner): - result = runner.invoke(app, ["list-feeds", "--bbox", "6.626953,49.423342,23.348144,54.265953"]) + result = runner.invoke( + app, ["list-feeds", "-pd", "intersects", "--bbox", "6.626953,49.423342,23.348144,54.265953"] + ) assert result.exit_code == 0 - assert "Feeds based on bbox input" in result.stdout + assert "Filtered feeds are:" in result.stdout def test_contains_predicate(self, runner): result = runner.invoke( - app, ["list-feeds", "-p", "contains", "--bbox", "6.626953,49.423342,23.348144,54.265953"] + app, ["list-feeds", "-pd", "contains", "--bbox", "6.626953,49.423342,23.348144,54.265953"] ) assert result.exit_code == 0 - assert "Feeds based on bbox input" in result.stdout + assert "Filtered feeds are:" in result.stdout + + def test_pretty(self, runner): + result = runner.invoke(app, ["list-feeds", "-pt"]) + assert result.exit_code == 0 + assert "Filtered feeds are:" in result.stdout From 57a9053295d17e587c9da5cc718f602fea49de52 Mon Sep 17 00:00:00 2001 From: Ananya Nayak Date: Mon, 5 Jun 2023 22:31:54 +0530 Subject: [PATCH 2/6] some cleanup --- gtfs/__main__.py | 72 +++++++++++++++++++++------------------------ gtfs/feed_source.py | 7 ++++- tests/test_cli.py | 3 -- 3 files changed, 39 insertions(+), 43 deletions(-) diff --git a/gtfs/__main__.py b/gtfs/__main__.py index 5a8de4f..e1465ad 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -1,7 +1,7 @@ #!/usr/bin/env python """Command line interface for fetching GTFS.""" import logging -from typing import Union +from typing import Optional import typer from prettytable.colortable import ColorTable, Themes @@ -17,7 +17,7 @@ app = typer.Typer() -def check_bbox(bbox: str) -> Union[Bbox, None]: +def check_bbox(bbox: str) -> Optional[Bbox]: if bbox is None: return try: @@ -42,7 +42,7 @@ def check_bbox(bbox: str) -> Union[Bbox, None]: @app.command() def list_feeds( bbox: Annotated[ - Union[str, None], + Optional[str], typer.Option( "--bbox", "-b", @@ -51,7 +51,7 @@ def list_feeds( ), ] = None, predicate: Annotated[ - Union[Predicate, None], + Optional[Predicate], typer.Option( "--predicate", "-pd", @@ -78,13 +78,10 @@ def list_feeds( ) else: spinner("Fetching feeds...", 1) - typer.secho("Filtered feeds are:", fg=typer.colors.BLUE) if pretty is True: - output = ColorTable( + pretty_output = ColorTable( ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 ) - else: - output = [] for src in feed_sources: feed_bbox: Bbox = src.bbox @@ -98,20 +95,19 @@ def list_feeds( continue if pretty is True: - row = [ - src.__name__, - src.url, - [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], - ] - output.add_row(row) - else: - output.append(src.url) + pretty_output.add_row( + [ + src.__name__, + src.url, + [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], + ] + ) + continue + + print(src.url) if pretty is True: - print("\n" + output.get_string()) - else: - print("\n".join(output)) - typer.secho("All done!", fg=typer.colors.GREEN) + print("\n" + pretty_output.get_string()) @app.command() @@ -121,26 +117,22 @@ def fetch_feeds(sources=None): """ statuses = {} # collect the statuses for all the files - # make a copy of the list of all modules in feed_sources; # default to use all of them if not sources: - sources = list(feed_sources) + sources = feed_sources LOG.info("Going to fetch feeds from sources: %s", sources) for src in sources: LOG.debug("Going to start fetch for %s...", src) try: - module = getattr(feed_sources, src) - # expect a class with the same name as the module; instantiate and fetch its feeds - klass = getattr(module, src) - if issubclass(klass, FeedSource): - inst = klass() + if issubclass(src, FeedSource): + inst = src() inst.fetch() statuses.update(inst.status) else: - LOG.warn( + LOG.warning( "Skipping class %s, which does not subclass FeedSource.", - klass.__name__, + src.__name__, ) except AttributeError: LOG.error("Skipping feed %s, which could not be found.", src) @@ -149,8 +141,18 @@ def fetch_feeds(sources=None): if "last_check" in statuses: del statuses["last_check"] - # display results - ptable = ColorTable() + ptable = ColorTable( + [ + "file", + "new?", + "valid?", + "current?", + "newly effective?", + "error", + ], + theme=Themes.OCEAN, + hrules=1, + ) for file_name in statuses: stat = statuses[file_name] @@ -166,14 +168,6 @@ def fetch_feeds(sources=None): msg.append("") ptable.add_row(msg) - ptable.field_names = [ - "file", - "new?", - "valid?", - "current?", - "newly effective?", - "error", - ] LOG.info("Results:\n%s", ptable.get_string()) LOG.info("All done!") diff --git a/gtfs/feed_source.py b/gtfs/feed_source.py index caca710..04926c3 100644 --- a/gtfs/feed_source.py +++ b/gtfs/feed_source.py @@ -27,4 +27,9 @@ def bbox(self) -> Bbox: pass def fetch(self): - pass + """ + Modify this method in subclass for importing feed(s) from agency. + + By default, loops over given URLs, checks the last-modified header to see if a new + download is available, streams the download if so, and verifies the new GTFS. + """ diff --git a/tests/test_cli.py b/tests/test_cli.py index 885793b..eca1b13 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -35,16 +35,13 @@ def test_intersects_predicate(self, runner): app, ["list-feeds", "-pd", "intersects", "--bbox", "6.626953,49.423342,23.348144,54.265953"] ) assert result.exit_code == 0 - assert "Filtered feeds are:" in result.stdout def test_contains_predicate(self, runner): result = runner.invoke( app, ["list-feeds", "-pd", "contains", "--bbox", "6.626953,49.423342,23.348144,54.265953"] ) assert result.exit_code == 0 - assert "Filtered feeds are:" in result.stdout def test_pretty(self, runner): result = runner.invoke(app, ["list-feeds", "-pt"]) assert result.exit_code == 0 - assert "Filtered feeds are:" in result.stdout From 86acde35190e62253bde8442e8fbf4e56ee1c00e Mon Sep 17 00:00:00 2001 From: Ananya Nayak Date: Wed, 7 Jun 2023 21:05:20 +0530 Subject: [PATCH 3/6] feat: fetch-feeds command --- .gitignore | 53 +----------- gtfs/__main__.py | 150 ++++++++++++++++++++++------------ gtfs/feed_source.py | 144 +++++++++++++++++++++++++++++++- gtfs/feed_sources/AlbanyNy.py | 3 + gtfs/feed_sources/Berlin.py | 3 + gtfs/utils/constants.py | 5 ++ tests/test_cli.py | 12 ++- 7 files changed, 258 insertions(+), 112 deletions(-) diff --git a/.gitignore b/.gitignore index b627ff1..675ca8c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,9 @@ .idea gtfs/__version__.py + # Generated / downloaded files *.zip *.p -*.csv -*.html -patco-gtfs/ -transitfeedcrash.txt # virtualenv .venv/ @@ -18,51 +15,3 @@ transitfeedcrash.txt # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] - -# C extensions -*.so - -# Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -*.egg-info/ -.installed.cfg -*.egg - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*,cover - -# Translations -*.mo -*.pot - -# PyBuilder -target/ diff --git a/gtfs/__main__.py b/gtfs/__main__.py index e1465ad..8e6c0d1 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python """Command line interface for fetching GTFS.""" -import logging +import os.path from typing import Optional import typer @@ -9,11 +9,9 @@ from .feed_source import FeedSource from .feed_sources import feed_sources -from .utils.constants import Predicate, spinner +from .utils.constants import LOG, Predicate, spinner from .utils.geom import Bbox, bbox_contains_bbox, bbox_intersects_bbox -logging.basicConfig() -LOG = logging.getLogger() app = typer.Typer() @@ -67,7 +65,12 @@ def list_feeds( ), ] = False, ) -> None: - """Filter feeds spatially based on bounding box.""" + """Filter feeds spatially based on bounding box or list all of them. + + :param bbox: set of coordinates to filter feeds spatially + :param predicate: the gtfs feed should intersect or should be contained inside the user's bbox + :param pretty: display feeds inside a pretty table + """ if bbox is None and predicate is not None: raise typer.BadParameter( f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!" @@ -83,6 +86,8 @@ def list_feeds( ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 ) + filtered_srcs = "" + for src in feed_sources: feed_bbox: Bbox = src.bbox if bbox is not None and predicate == "contains": @@ -94,6 +99,7 @@ def list_feeds( ): continue + filtered_srcs += src.__name__ + ", " if pretty is True: pretty_output.add_row( [ @@ -109,67 +115,107 @@ def list_feeds( if pretty is True: print("\n" + pretty_output.get_string()) + if typer.confirm("Do you want to fetch feeds from these sources?"): + fetch_feeds(sources=filtered_srcs[:-1]) + @app.command() -def fetch_feeds(sources=None): - """ +def fetch_feeds( + sources: Annotated[ + Optional[str], + typer.Option( + "--sources", + "-src", + help="pass value as a string separated by commas like this: src1,src2,src3", + ), + ] = None, + search: Annotated[ + Optional[str], + typer.Option( + "--search", + "-s", + help="search for feeds based on a string", + ), + ] = None, + output_dir: Annotated[ + Optional[str], + typer.Option( + "--output-dir", + "-o", + help="the directory where the downloaded feeds will be saved", + ), + ] = "feeds", +) -> None: + """Fetch feeds from sources. + :param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available. + :param search: Search for feeds based on a string. + :param output_dir: The directory where the downloaded feeds will be saved; default is feeds. """ - statuses = {} # collect the statuses for all the files + # statuses = {} # collect the statuses for all the files - # default to use all of them if not sources: - sources = feed_sources + if not search: + # fetch all feeds + sources = feed_sources + else: + # fetch feeds based on search + sources = [ + src + for src in feed_sources + if search.lower() in src.__name__.lower() or search.lower() in src.url.lower() + ] + else: + # fetch feeds based on sources + sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] + + output_dir_path = os.path.join(os.getcwd(), output_dir) + if not os.path.exists(output_dir_path): + os.makedirs(output_dir_path) - LOG.info("Going to fetch feeds from sources: %s", sources) + LOG.info(f"Going to fetch feeds from sources: {sources}") for src in sources: - LOG.debug("Going to start fetch for %s...", src) + LOG.debug(f"Going to start fetch for {src}...") try: if issubclass(src, FeedSource): inst = src() + inst.ddir = output_dir_path + inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") inst.fetch() - statuses.update(inst.status) + # statuses.update(inst.status) else: - LOG.warning( - "Skipping class %s, which does not subclass FeedSource.", - src.__name__, - ) + LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") except AttributeError: - LOG.error("Skipping feed %s, which could not be found.", src) - - # remove last check key set at top level of each status dictionary - if "last_check" in statuses: - del statuses["last_check"] - - ptable = ColorTable( - [ - "file", - "new?", - "valid?", - "current?", - "newly effective?", - "error", - ], - theme=Themes.OCEAN, - hrules=1, - ) - - for file_name in statuses: - stat = statuses[file_name] - msg = [] - msg.append(file_name) - msg.append("x" if "is_new" in stat and stat["is_new"] else "") - msg.append("x" if "is_valid" in stat and stat["is_valid"] else "") - msg.append("x" if "is_current" in stat and stat["is_current"] else "") - msg.append("x" if "newly_effective" in stat and stat.get("newly_effective") else "") - if "error" in stat: - msg.append(stat["error"]) - else: - msg.append("") - ptable.add_row(msg) - - LOG.info("Results:\n%s", ptable.get_string()) - LOG.info("All done!") + LOG.error(f"Skipping feed {src}, which could not be found.") + + # ptable = ColorTable( + # [ + # "file", + # "new?", + # "valid?", + # "current?", + # "newly effective?", + # "error", + # ], + # theme=Themes.OCEAN, + # hrules=1, + # ) + # + # for file_name in statuses: + # stat = statuses[file_name] + # msg = [] + # msg.append(file_name) + # msg.append("x" if "is_new" in stat and stat["is_new"] else "") + # msg.append("x" if "is_valid" in stat and stat["is_valid"] else "") + # msg.append("x" if "is_current" in stat and stat["is_current"] else "") + # msg.append("x" if "newly_effective" in stat and stat.get("newly_effective") else "") + # if "error" in stat: + # msg.append(stat["error"]) + # else: + # msg.append("") + # ptable.add_row(msg) + # + # LOG.info("\n" + ptable.get_string()) if __name__ == "__main__": diff --git a/gtfs/feed_source.py b/gtfs/feed_source.py index 04926c3..6c9840c 100644 --- a/gtfs/feed_source.py +++ b/gtfs/feed_source.py @@ -2,10 +2,22 @@ To add a new feed, add a subclass of this to the `feed_sources` directory. """ +import os +import pickle + +# import subprocess +import zipfile from abc import ABC, abstractmethod +from datetime import datetime + +import requests +from gtfs.utils.constants import LOG from gtfs.utils.geom import Bbox +# format time checks like last-modified header +TIMECHECK_FMT = "%a, %d %b %Y %H:%M:%S GMT" + class FeedSource(ABC): """Base class for a GTFS source. Class and module names are expected to match. @@ -16,6 +28,11 @@ class FeedSource(ABC): - override :fetch: method as necessary to fetch feeds for the agency. """ + def __init__(self): + self.status = {} + self.ddir = "" + self.status_file = "" + @property @abstractmethod def url(self) -> str: @@ -26,10 +43,129 @@ def url(self) -> str: def bbox(self) -> Bbox: pass - def fetch(self): - """ - Modify this method in subclass for importing feed(s) from agency. + def write_status(self): + """Write pickled log of feed statuses and last times files were downloaded.""" + LOG.debug(f"Downloading finished. Writing status file {self.status_file}...") + with open(self.status_file, "wb") as status_file: + pickle.dump(self.status, status_file) + LOG.debug(f"Statuses written to {self.status_file}.") + + # TODO - add a method to verify the feed + def fetch(self) -> bool: + """Modify this method in subclass for importing feed(s) from agency. - By default, loops over given URLs, checks the last-modified header to see if a new + By default, checks the last-modified header to see if a new download is available, streams the download if so, and verifies the new GTFS. """ + if self.url: + feed_file = self.__class__.__name__ + if self.download_feed(feed_file, self.url): + self.write_status() + # if self.verify(feed_file): + # LOG.info('GTFS verification succeeded.') + # return True + # else: + # LOG.error('GTFS verification failed.') + # return False + else: + return False + else: + raise ValueError("URL not set for feed source!") + + def check_header_newer(self, feed_file: str, url: str): + """Check if last-modified header indicates a new download is available. + + :param feed_file: Name of downloaded file (relative to :ddir:) + :param url: Where GTFS is downloaded from + :returns: 1 if newer GTFS available; 0 if info missing; -1 if already have most recent + """ + if not os.path.exists(self.status_file): + LOG.debug(f"Status file {self.status_file} not found.") + return 0 + + with open(self.status_file, "rb") as f: + last_status = pickle.load(f) + if feed_file in last_status and "posted_date" in last_status[feed_file]: + last_fetch = last_status[feed_file]["posted_date"] + hdr = requests.head(url) + hdr = hdr.headers + if hdr.get("last-modified"): + last_mod = hdr.get("last-modified") + if last_fetch >= last_mod: + LOG.info(f"No new download available for {feed_file}.") + return -1 + else: + LOG.info(f"New download available for {feed_file}.") + LOG.info(f"Last download from: {last_fetch}.") + LOG.info(f"New download posted: {last_mod}") + return 1 + else: + # should try to find another way to check for new feeds if header not set + LOG.debug(f"No last-modified header set for {feed_file} download link.") + return 0 + else: + LOG.debug(f"Time check entry for {feed_file} not found.") + return 0 + + def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> bool: + """Download feed. + + :param feed_file: File name to save download as, relative to :ddir: + :param url: Where to download the GTFS from + :param do_stream: If True, stream the download + :returns: True if download was successful + """ + if self.check_header_newer(feed_file, url) == -1: + # Nothing new to fetch; done here + return False + + # feed_file is local to download directory + feed_file_path = os.path.join(self.ddir, feed_file + ".zip") + LOG.info(f"Getting file {feed_file}...from...{url}") + request = requests.get(url, stream=do_stream) + + if request.ok: + with open(feed_file_path, "wb") as download_file: + if do_stream: + for chunk in request.iter_content(chunk_size=1024): + download_file.write(chunk) + else: + download_file.write(request.content) + + info = os.stat(feed_file_path) + if info.st_size < 10000: + # file smaller than 10K; may not be a GTFS + LOG.warning(f"Download for {feed_file_path} is only {str(info.st_size)} bytes.") + if not zipfile.is_zipfile(feed_file_path): + self.set_error(feed_file, "Download is not a zip file") + return False + posted_date = request.headers.get("last-modified") + if not posted_date: + LOG.debug("No last-modified header set") + posted_date = datetime.utcnow().strftime(TIMECHECK_FMT) + self.set_posted_date(feed_file, posted_date) + LOG.info("Download completed successfully.") + return True + else: + self.set_error(feed_file, "Download failed") + return False + + def set_posted_date(self, feed_file: str, posted_date: str): + """Update feed status posted date. Creates new feed status if none found. + + :param feed_file: Name of feed file, relative to :ddir: + :param posted_date: Date string formatted to :TIMECHECK_FMT: when feed was posted + """ + stat = self.status.get(feed_file, {}) + stat["posted_date"] = posted_date + self.status[feed_file] = stat + + def set_error(self, feed_file: str, msg: str): + """If error encountered in processing, set status error message, and unset other fields. + + :param feed_file: Name of feed file, relative to :ddir: + :param msg: Error message to save with status + """ + LOG.error(f"Error processing {feed_file}: {msg}") + self.status[feed_file] = {"error": msg} + self.write_status() diff --git a/gtfs/feed_sources/AlbanyNy.py b/gtfs/feed_sources/AlbanyNy.py index 0df3681..37ce15a 100644 --- a/gtfs/feed_sources/AlbanyNy.py +++ b/gtfs/feed_sources/AlbanyNy.py @@ -8,3 +8,6 @@ class AlbanyNy(FeedSource): url: str = "http://www.cdta.org/schedules/google_transit.zip" bbox: Bbox = Bbox(-74.219321, 42.467161, -73.614608, 43.10706) + + def __init__(self): + super().__init__() diff --git a/gtfs/feed_sources/Berlin.py b/gtfs/feed_sources/Berlin.py index 7e6c63a..7a703d2 100644 --- a/gtfs/feed_sources/Berlin.py +++ b/gtfs/feed_sources/Berlin.py @@ -10,3 +10,6 @@ class Berlin(FeedSource): "https://www.vbb.de/fileadmin/user_upload/VBB/Dokumente/API-Datensaetze/gtfs-mastscharf/GTFS.zip" ) bbox: Bbox = Bbox(10.669821, 50.839245, 17.037088, 54.308626) + + def __init__(self): + super().__init__() diff --git a/gtfs/utils/constants.py b/gtfs/utils/constants.py index e4e62e0..bf4cb6e 100644 --- a/gtfs/utils/constants.py +++ b/gtfs/utils/constants.py @@ -1,3 +1,4 @@ +import logging import time from enum import Enum @@ -17,3 +18,7 @@ def spinner(text: str, timer: int) -> None: ) as progress: progress.add_task(description=text, total=None) time.sleep(timer) + + +logging.basicConfig(level=logging.INFO) +LOG = logging.getLogger() diff --git a/tests/test_cli.py b/tests/test_cli.py index eca1b13..3d02fb2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,7 +13,7 @@ class TestListFeedsCommand: def test_help(self, runner): result = runner.invoke(app, ["list-feeds", "--help"]) assert result.exit_code == 0 - assert "Filter feeds spatially based on bounding box." in result.stdout + assert "Filter feeds spatially based on bounding box or list all of them." in result.stdout def test_bad_args_1(self, runner): result = runner.invoke(app, ["list-feeds", "--bbox", "6.626953,49.423342,23.348144"]) @@ -32,16 +32,20 @@ def test_bad_args_3(self, runner): def test_intersects_predicate(self, runner): result = runner.invoke( - app, ["list-feeds", "-pd", "intersects", "--bbox", "6.626953,49.423342,23.348144,54.265953"] + app, + ["list-feeds", "-pd", "intersects", "--bbox", "6.626953,49.423342,23.348144,54.265953"], + input="N\n", ) assert result.exit_code == 0 def test_contains_predicate(self, runner): result = runner.invoke( - app, ["list-feeds", "-pd", "contains", "--bbox", "6.626953,49.423342,23.348144,54.265953"] + app, + ["list-feeds", "-pd", "contains", "--bbox", "6.626953,49.423342,23.348144,54.265953"], + input="N\n", ) assert result.exit_code == 0 def test_pretty(self, runner): - result = runner.invoke(app, ["list-feeds", "-pt"]) + result = runner.invoke(app, ["list-feeds", "-pt"], input="N\n") assert result.exit_code == 0 From a3c6417c013131820522e3f2f19267ad4b1b833d Mon Sep 17 00:00:00 2001 From: Ananya Nayak Date: Fri, 9 Jun 2023 11:35:59 +0530 Subject: [PATCH 4/6] chore: multithreading for I/O requests --- .gitignore | 2 +- feeds/AlbanyNy.pkl | Bin 0 -> 75 bytes feeds/Berlin.pkl | Bin 0 -> 73 bytes gtfs/__main__.py | 82 +++++++++++++++++++++++++++++++------------- gtfs/feed_source.py | 2 +- tests/test_cli.py | 20 +++++++++++ 6 files changed, 81 insertions(+), 25 deletions(-) create mode 100644 feeds/AlbanyNy.pkl create mode 100644 feeds/Berlin.pkl diff --git a/.gitignore b/.gitignore index 675ca8c..b0441f3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ gtfs/__version__.py # Generated / downloaded files *.zip -*.p +*.pkl # virtualenv .venv/ diff --git a/feeds/AlbanyNy.pkl b/feeds/AlbanyNy.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d23c74e2495b9c1a9411f14955081550ea73a95d GIT binary patch literal 75 zcmZo*nd-m*0kuQ&V&lj13ih6Dt*r X42+Bw49%@ffXG0>-8W=Paj_l%VAU3z literal 0 HcmV?d00001 diff --git a/feeds/Berlin.pkl b/feeds/Berlin.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9b1ea0b5cd464b2592138608f3e78f7245ae8296 GIT binary patch literal 73 zcmZo*nQF%X0ku$mtW1opj13jseM6=c7wZ84FT@q? literal 0 HcmV?d00001 diff --git a/gtfs/__main__.py b/gtfs/__main__.py index 8e4e23a..ad16df3 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """Command line interface for fetching GTFS.""" import os +import threading from typing import Optional import typer @@ -37,6 +38,18 @@ def check_bbox(bbox: str) -> Optional[Bbox]: return Bbox(min_x, min_y, max_x, max_y) +def check_sources(sources: str) -> Optional[str]: + """Check if the sources are valid.""" + if sources is None: + return None + sources = sources.split(",") + for source in sources: + if not any(src.__name__.lower() == source.lower() for src in feed_sources): + raise typer.BadParameter(f"{source} is not a valid feed source!") + + return ",".join(sources) + + @app.command() def list_feeds( bbox: Annotated[ @@ -86,10 +99,8 @@ def list_feeds( ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 ) - filtered_srcs = "" - for src in feed_sources: feed_bbox: Bbox = src.bbox if bbox is not None and predicate == "contains": @@ -101,7 +112,6 @@ def list_feeds( ): continue - filtered_srcs += src.__name__ + ", " if pretty is True: @@ -116,7 +126,6 @@ def list_feeds( print(src.url) - if pretty is True: print("\n" + pretty_output.get_string()) @@ -124,7 +133,6 @@ def list_feeds( fetch_feeds(sources=filtered_srcs[:-1]) - @app.command() def fetch_feeds( sources: Annotated[ @@ -132,7 +140,8 @@ def fetch_feeds( typer.Option( "--sources", "-src", - help="pass value as a string separated by commas like this: src1,src2,src3", + help="pass value as a string separated by commas like this: Berlin,AlbanyNy,...", + callback=check_sources, ), ] = None, search: Annotated[ @@ -148,19 +157,27 @@ def fetch_feeds( typer.Option( "--output-dir", "-o", - help="the directory where the downloaded feeds will be saved", + help="the directory where the downloaded feeds will be saved, default is feeds", ), ] = "feeds", + concurrency: Annotated[ + Optional[int], + typer.Option( + "--concurrency", + "-c", + help="the number of concurrent downloads, default is 4", + ), + ] = 4, ) -> None: """Fetch feeds from sources. :param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available. :param search: Search for feeds based on a string. :param output_dir: The directory where the downloaded feeds will be saved; default is feeds. + :param concurrency: The number of concurrent downloads; default is 4. """ # statuses = {} # collect the statuses for all the files - if not sources: if not search: # fetch all feeds @@ -173,8 +190,10 @@ def fetch_feeds( if search.lower() in src.__name__.lower() or search.lower() in src.url.lower() ] else: - # fetch feeds based on sources - sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] + if search: + raise typer.BadParameter("Please pass either sources or search, not both at the same time!") + else: + sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] output_dir_path = os.path.join(os.getcwd(), output_dir) if not os.path.exists(output_dir_path): @@ -182,19 +201,36 @@ def fetch_feeds( LOG.info(f"Going to fetch feeds from sources: {sources}") - for src in sources: - LOG.debug(f"Going to start fetch for {src}...") - try: - if issubclass(src, FeedSource): - inst = src() - inst.ddir = output_dir_path - inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") - inst.fetch() - # statuses.update(inst.status) - else: - LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") - except AttributeError: - LOG.error(f"Skipping feed {src}, which could not be found.") + threads = [] + + def thread_worker(): + while True: + try: + src = sources.pop(0) + except IndexError: + break + + LOG.debug(f"Going to start fetch for {src}...") + try: + if issubclass(src, FeedSource): + inst = src() + inst.ddir = output_dir_path + inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") + inst.fetch() + # statuses.update(inst.status) + else: + LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") + except AttributeError: + LOG.error(f"Skipping feed {src}, which could not be found.") + + for _ in range(concurrency): + thread = threading.Thread(target=thread_worker) + thread.start() + threads.append(thread) + + # Wait for all threads to complete + for thread in threads: + thread.join() # ptable = ColorTable( # [ diff --git a/gtfs/feed_source.py b/gtfs/feed_source.py index 6c9840c..e6f1479 100644 --- a/gtfs/feed_source.py +++ b/gtfs/feed_source.py @@ -144,7 +144,7 @@ def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> boo LOG.debug("No last-modified header set") posted_date = datetime.utcnow().strftime(TIMECHECK_FMT) self.set_posted_date(feed_file, posted_date) - LOG.info("Download completed successfully.") + LOG.info(f"Download completed successfully for {feed_file}.") return True else: self.set_error(feed_file, "Download failed") diff --git a/tests/test_cli.py b/tests/test_cli.py index 3d02fb2..3428b1b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -49,3 +49,23 @@ def test_contains_predicate(self, runner): def test_pretty(self, runner): result = runner.invoke(app, ["list-feeds", "-pt"], input="N\n") assert result.exit_code == 0 + + +class TestFetchFeedsCommand: + def test_help(self, runner): + result = runner.invoke(app, ["fetch-feeds", "--help"]) + assert result.exit_code == 0 + assert "Fetch feeds from sources." in result.stdout + + def test_bad_args(self, runner): + result = runner.invoke(app, ["fetch-feeds", "-src", "berlin", "-s", "cdta"]) + assert result.exit_code == 2 + assert "Please pass either sources or search" in result.stdout + + def test_fetch_with_sources(self, runner): + result = runner.invoke(app, ["fetch-feeds", "-src", "berlin"]) + assert result.exit_code == 0 + + def test_fetch_with_search(self, runner): + result = runner.invoke(app, ["fetch-feeds", "-s", "cdta"]) + assert result.exit_code == 0 From aa907c73c4bc1e22c544fb6d72db53210fc565b9 Mon Sep 17 00:00:00 2001 From: Ananya Nayak Date: Tue, 13 Jun 2023 20:47:30 +0530 Subject: [PATCH 5/6] fix: search option in list-feeds --- feeds/AlbanyNy.pkl | Bin 75 -> 0 bytes feeds/Berlin.pkl | Bin 73 -> 0 bytes gtfs/__main__.py | 165 +++++++++++++++++----------------------- gtfs/feed_source.py | 9 +-- gtfs/utils/constants.py | 3 + gtfs/utils/validator.py | 0 poetry.lock | 2 +- tests/test_cli.py | 22 ++---- 8 files changed, 81 insertions(+), 120 deletions(-) delete mode 100644 feeds/AlbanyNy.pkl delete mode 100644 feeds/Berlin.pkl create mode 100644 gtfs/utils/validator.py diff --git a/feeds/AlbanyNy.pkl b/feeds/AlbanyNy.pkl deleted file mode 100644 index d23c74e2495b9c1a9411f14955081550ea73a95d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 75 zcmZo*nd-m*0kuQ&V&lj13ih6Dt*r X42+Bw49%@ffXG0>-8W=Paj_l%VAU3z diff --git a/feeds/Berlin.pkl b/feeds/Berlin.pkl deleted file mode 100644 index 9b1ea0b5cd464b2592138608f3e78f7245ae8296..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 73 zcmZo*nQF%X0ku$mtW1opj13jseM6=c7wZ84FT@q? diff --git a/gtfs/__main__.py b/gtfs/__main__.py index ad16df3..fcf5c0a 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -13,7 +13,7 @@ from .utils.constants import LOG, Predicate, spinner from .utils.geom import Bbox, bbox_contains_bbox, bbox_intersects_bbox -app = typer.Typer() +app = typer.Typer(help="Fetch GTFS feeds from various transit agencies.") def check_bbox(bbox: str) -> Optional[Bbox]: @@ -69,6 +69,14 @@ def list_feeds( help="the gtfs feed should intersect or should be contained inside the user's bbox", ), ] = None, + search: Annotated[ + Optional[str], + typer.Option( + "--search", + "-s", + help="search for feeds based on a string", + ), + ] = None, pretty: Annotated[ bool, typer.Option( @@ -78,59 +86,74 @@ def list_feeds( ), ] = False, ) -> None: - """Filter feeds spatially based on bounding box or list all of them. + """Filter feeds spatially based on bounding box or search string. :param bbox: set of coordinates to filter feeds spatially :param predicate: the gtfs feed should intersect or should be contained inside the user's bbox + :param search: Search for feeds based on a string. :param pretty: display feeds inside a pretty table """ - if bbox is None and predicate is not None: - raise typer.BadParameter( - f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!" - ) - elif bbox is not None and predicate is None: - raise typer.BadParameter( - f"Please pass a predicate if you want to filter feeds spatially based on bbox = {bbox}!" - ) + sources: list = feed_sources + + if search is not None: + if bbox is not None or predicate is not None: + raise typer.BadParameter("Please pass either bbox or search, not both at the same time!") + else: + sources = [ + src + for src in feed_sources + if search.lower() in src.__name__.lower() or search.lower() in src.url.lower() + ] else: - spinner("Fetching feeds...", 1) - if pretty is True: - pretty_output = ColorTable( - ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 + if bbox is None and predicate is not None: + raise typer.BadParameter( + f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!" + ) + elif bbox is not None and predicate is None: + raise typer.BadParameter( + f"Please pass a predicate if you want to filter feeds spatially based on bbox = {bbox}!" ) + else: + pass + + spinner("Fetching feeds...", 1) + if pretty is True: + pretty_output = ColorTable( + ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 + ) + + filtered_srcs: str = "" - filtered_srcs = "" - - for src in feed_sources: - feed_bbox: Bbox = src.bbox - if bbox is not None and predicate == "contains": - if not bbox_contains_bbox(feed_bbox, bbox): - continue - elif bbox is not None and predicate == "intersects": - if (not bbox_intersects_bbox(feed_bbox, bbox)) and ( - not bbox_intersects_bbox(bbox, feed_bbox) - ): - continue - - filtered_srcs += src.__name__ + ", " - - if pretty is True: - pretty_output.add_row( - [ - src.__name__, - src.url, - [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], - ] - ) + for src in sources: + feed_bbox: Bbox = src.bbox + if bbox is not None and predicate == "contains": + if not bbox_contains_bbox(feed_bbox, bbox): + continue + elif bbox is not None and predicate == "intersects": + if (not bbox_intersects_bbox(feed_bbox, bbox)) and ( + not bbox_intersects_bbox(bbox, feed_bbox) + ): continue - print(src.url) + filtered_srcs += src.__name__ + ", " if pretty is True: - print("\n" + pretty_output.get_string()) + pretty_output.add_row( + [ + src.__name__, + src.url, + [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], + ] + ) + continue + + print(src.url) + + if pretty is True: + print("\n" + pretty_output.get_string()) - if typer.confirm("Do you want to fetch feeds from these sources?"): - fetch_feeds(sources=filtered_srcs[:-1]) + if typer.confirm("Do you want to fetch feeds from these sources?"): + fetch_feeds(sources=filtered_srcs[:-1]) @app.command() @@ -144,16 +167,8 @@ def fetch_feeds( callback=check_sources, ), ] = None, - search: Annotated[ - Optional[str], - typer.Option( - "--search", - "-s", - help="search for feeds based on a string", - ), - ] = None, output_dir: Annotated[ - Optional[str], + str, typer.Option( "--output-dir", "-o", @@ -172,28 +187,14 @@ def fetch_feeds( """Fetch feeds from sources. :param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available. - :param search: Search for feeds based on a string. :param output_dir: The directory where the downloaded feeds will be saved; default is feeds. :param concurrency: The number of concurrent downloads; default is 4. """ - # statuses = {} # collect the statuses for all the files if not sources: - if not search: - # fetch all feeds - sources = feed_sources - else: - # fetch feeds based on search - sources = [ - src - for src in feed_sources - if search.lower() in src.__name__.lower() or search.lower() in src.url.lower() - ] + sources = feed_sources else: - if search: - raise typer.BadParameter("Please pass either sources or search, not both at the same time!") - else: - sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] + sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] output_dir_path = os.path.join(os.getcwd(), output_dir) if not os.path.exists(output_dir_path): @@ -201,7 +202,7 @@ def fetch_feeds( LOG.info(f"Going to fetch feeds from sources: {sources}") - threads = [] + threads: list[threading.Thread] = [] def thread_worker(): while True: @@ -217,7 +218,6 @@ def thread_worker(): inst.ddir = output_dir_path inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") inst.fetch() - # statuses.update(inst.status) else: LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") except AttributeError: @@ -232,35 +232,6 @@ def thread_worker(): for thread in threads: thread.join() - # ptable = ColorTable( - # [ - # "file", - # "new?", - # "valid?", - # "current?", - # "newly effective?", - # "error", - # ], - # theme=Themes.OCEAN, - # hrules=1, - # ) - # - # for file_name in statuses: - # stat = statuses[file_name] - # msg = [] - # msg.append(file_name) - # msg.append("x" if "is_new" in stat and stat["is_new"] else "") - # msg.append("x" if "is_valid" in stat and stat["is_valid"] else "") - # msg.append("x" if "is_current" in stat and stat["is_current"] else "") - # msg.append("x" if "newly_effective" in stat and stat.get("newly_effective") else "") - # if "error" in stat: - # msg.append(stat["error"]) - # else: - # msg.append("") - # ptable.add_row(msg) - # - # LOG.info("\n" + ptable.get_string()) - if __name__ == "__main__": app() diff --git a/gtfs/feed_source.py b/gtfs/feed_source.py index e6f1479..f6f2869 100644 --- a/gtfs/feed_source.py +++ b/gtfs/feed_source.py @@ -4,20 +4,15 @@ """ import os import pickle - -# import subprocess import zipfile from abc import ABC, abstractmethod from datetime import datetime import requests -from gtfs.utils.constants import LOG +from gtfs.utils.constants import LOG, TIMECHECK_FMT from gtfs.utils.geom import Bbox -# format time checks like last-modified header -TIMECHECK_FMT = "%a, %d %b %Y %H:%M:%S GMT" - class FeedSource(ABC): """Base class for a GTFS source. Class and module names are expected to match. @@ -50,7 +45,6 @@ def write_status(self): pickle.dump(self.status, status_file) LOG.debug(f"Statuses written to {self.status_file}.") - # TODO - add a method to verify the feed def fetch(self) -> bool: """Modify this method in subclass for importing feed(s) from agency. @@ -119,7 +113,6 @@ def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> boo # Nothing new to fetch; done here return False - # feed_file is local to download directory feed_file_path = os.path.join(self.ddir, feed_file + ".zip") LOG.info(f"Getting file {feed_file}...from...{url}") request = requests.get(url, stream=do_stream) diff --git a/gtfs/utils/constants.py b/gtfs/utils/constants.py index bf4cb6e..9a1dbea 100644 --- a/gtfs/utils/constants.py +++ b/gtfs/utils/constants.py @@ -22,3 +22,6 @@ def spinner(text: str, timer: int) -> None: logging.basicConfig(level=logging.INFO) LOG = logging.getLogger() + +# format time checks like last-modified header +TIMECHECK_FMT = "%a, %d %b %Y %H:%M:%S GMT" diff --git a/gtfs/utils/validator.py b/gtfs/utils/validator.py new file mode 100644 index 0000000..e69de29 diff --git a/poetry.lock b/poetry.lock index dda3005..7fa424a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "certifi" diff --git a/tests/test_cli.py b/tests/test_cli.py index 3428b1b..7438cfd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,7 +13,7 @@ class TestListFeedsCommand: def test_help(self, runner): result = runner.invoke(app, ["list-feeds", "--help"]) assert result.exit_code == 0 - assert "Filter feeds spatially based on bounding box or list all of them." in result.stdout + assert "Filter feeds spatially based on bounding box or search string." in result.stdout def test_bad_args_1(self, runner): result = runner.invoke(app, ["list-feeds", "--bbox", "6.626953,49.423342,23.348144"]) @@ -30,6 +30,13 @@ def test_bad_args_3(self, runner): assert result.exit_code == 2 assert "Area cannot be zero!" in result.stdout + def test_bad_args_4(self, runner): + result = runner.invoke( + app, ["list-feeds", "--bbox", "6.626953,49.423342,23.348144,54.265953", "--search", "cdta"] + ) + assert result.exit_code == 2 + assert "Please pass either bbox or search" in result.stdout + def test_intersects_predicate(self, runner): result = runner.invoke( app, @@ -56,16 +63,3 @@ def test_help(self, runner): result = runner.invoke(app, ["fetch-feeds", "--help"]) assert result.exit_code == 0 assert "Fetch feeds from sources." in result.stdout - - def test_bad_args(self, runner): - result = runner.invoke(app, ["fetch-feeds", "-src", "berlin", "-s", "cdta"]) - assert result.exit_code == 2 - assert "Please pass either sources or search" in result.stdout - - def test_fetch_with_sources(self, runner): - result = runner.invoke(app, ["fetch-feeds", "-src", "berlin"]) - assert result.exit_code == 0 - - def test_fetch_with_search(self, runner): - result = runner.invoke(app, ["fetch-feeds", "-s", "cdta"]) - assert result.exit_code == 0 From ae2fdecdffc82dd9c5a951855dfffadcc8494e47 Mon Sep 17 00:00:00 2001 From: Ananya Nayak Date: Fri, 16 Jun 2023 18:37:19 +0530 Subject: [PATCH 6/6] chore: some refactoring --- gtfs/__main__.py | 114 ++++++----------------------------- gtfs/feed_source.py | 42 ++++++------- gtfs/utils/check_params.py | 58 ++++++++++++++++++ gtfs/utils/constants.py | 6 ++ gtfs/utils/multithreading.py | 38 ++++++++++++ gtfs/utils/validator.py | 0 6 files changed, 140 insertions(+), 118 deletions(-) create mode 100644 gtfs/utils/check_params.py create mode 100644 gtfs/utils/multithreading.py delete mode 100644 gtfs/utils/validator.py diff --git a/gtfs/__main__.py b/gtfs/__main__.py index fcf5c0a..78bcb4e 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -1,55 +1,20 @@ #!/usr/bin/env python """Command line interface for fetching GTFS.""" -import os -import threading from typing import Optional import typer from prettytable.colortable import ColorTable, Themes from typing_extensions import Annotated -from .feed_source import FeedSource from .feed_sources import feed_sources +from .utils.check_params import check_bbox, check_output_dir, check_sources from .utils.constants import LOG, Predicate, spinner from .utils.geom import Bbox, bbox_contains_bbox, bbox_intersects_bbox +from .utils.multithreading import multi_fetch app = typer.Typer(help="Fetch GTFS feeds from various transit agencies.") -def check_bbox(bbox: str) -> Optional[Bbox]: - if bbox is None: - return - try: - min_x, min_y, max_x, max_y = [float(coord) for coord in bbox.split(",")] - except ValueError as e: - err_message = e.args[0] - if "could not convert" in err_message: - raise typer.BadParameter("Please pass only numbers as bbox values!") - elif "not enough values to unpack" in err_message: - raise typer.BadParameter( - "Please pass bbox as a string separated by commas like this: min_x,min_y,max_x,max_y" - ) - else: - raise typer.BadParameter(f"Unhandled exception: {e}") - - if min_x == max_x or min_y == max_y: - raise typer.BadParameter("Area cannot be zero! Please pass a valid bbox.") - - return Bbox(min_x, min_y, max_x, max_y) - - -def check_sources(sources: str) -> Optional[str]: - """Check if the sources are valid.""" - if sources is None: - return None - sources = sources.split(",") - for source in sources: - if not any(src.__name__.lower() == source.lower() for src in feed_sources): - raise typer.BadParameter(f"{source} is not a valid feed source!") - - return ",".join(sources) - - @app.command() def list_feeds( bbox: Annotated[ @@ -97,33 +62,29 @@ def list_feeds( if search is not None: if bbox is not None or predicate is not None: - raise typer.BadParameter("Please pass either bbox or search, not both at the same time!") - else: - sources = [ - src - for src in feed_sources - if search.lower() in src.__name__.lower() or search.lower() in src.url.lower() - ] - else: - if bbox is None and predicate is not None: raise typer.BadParameter( - f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!" - ) - elif bbox is not None and predicate is None: - raise typer.BadParameter( - f"Please pass a predicate if you want to filter feeds spatially based on bbox = {bbox}!" + "Please pass either bbox or search text, not both at the same time!" ) else: - pass + sources = [src for src in feed_sources if search.lower() in src.__name__.lower()] + + if bbox is None and predicate is not None: + raise typer.BadParameter( + f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!" + ) + + if bbox is not None and predicate is None: + raise typer.BadParameter( + f"Please pass a predicate if you want to filter feeds spatially based on bbox = {bbox}!" + ) + + spinner("Filtering feeds...", 1) - spinner("Fetching feeds...", 1) if pretty is True: pretty_output = ColorTable( ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 ) - filtered_srcs: str = "" - for src in sources: feed_bbox: Bbox = src.bbox if bbox is not None and predicate == "contains": @@ -135,8 +96,6 @@ def list_feeds( ): continue - filtered_srcs += src.__name__ + ", " - if pretty is True: pretty_output.add_row( [ @@ -152,9 +111,6 @@ def list_feeds( if pretty is True: print("\n" + pretty_output.get_string()) - if typer.confirm("Do you want to fetch feeds from these sources?"): - fetch_feeds(sources=filtered_srcs[:-1]) - @app.command() def fetch_feeds( @@ -173,6 +129,7 @@ def fetch_feeds( "--output-dir", "-o", help="the directory where the downloaded feeds will be saved, default is feeds", + callback=check_output_dir, ), ] = "feeds", concurrency: Annotated[ @@ -187,7 +144,8 @@ def fetch_feeds( """Fetch feeds from sources. :param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available. - :param output_dir: The directory where the downloaded feeds will be saved; default is feeds. + :param output_dir: The directory where the downloaded feeds will be saved; default is "feeds" + in current working directory. :param concurrency: The number of concurrent downloads; default is 4. """ @@ -196,41 +154,9 @@ def fetch_feeds( else: sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] - output_dir_path = os.path.join(os.getcwd(), output_dir) - if not os.path.exists(output_dir_path): - os.makedirs(output_dir_path) - LOG.info(f"Going to fetch feeds from sources: {sources}") - threads: list[threading.Thread] = [] - - def thread_worker(): - while True: - try: - src = sources.pop(0) - except IndexError: - break - - LOG.debug(f"Going to start fetch for {src}...") - try: - if issubclass(src, FeedSource): - inst = src() - inst.ddir = output_dir_path - inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") - inst.fetch() - else: - LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") - except AttributeError: - LOG.error(f"Skipping feed {src}, which could not be found.") - - for _ in range(concurrency): - thread = threading.Thread(target=thread_worker) - thread.start() - threads.append(thread) - - # Wait for all threads to complete - for thread in threads: - thread.join() + multi_fetch(sources, output_dir, concurrency) if __name__ == "__main__": diff --git a/gtfs/feed_source.py b/gtfs/feed_source.py index f6f2869..8e214ed 100644 --- a/gtfs/feed_source.py +++ b/gtfs/feed_source.py @@ -10,7 +10,7 @@ import requests -from gtfs.utils.constants import LOG, TIMECHECK_FMT +from gtfs.utils.constants import LOG, TIMECHECK_FMT, Feed from gtfs.utils.geom import Bbox @@ -38,7 +38,7 @@ def url(self) -> str: def bbox(self) -> Bbox: pass - def write_status(self): + def write_status(self) -> None: """Write pickled log of feed statuses and last times files were downloaded.""" LOG.debug(f"Downloading finished. Writing status file {self.status_file}...") with open(self.status_file, "wb") as status_file: @@ -51,22 +51,16 @@ def fetch(self) -> bool: By default, checks the last-modified header to see if a new download is available, streams the download if so, and verifies the new GTFS. """ - if self.url: - feed_file = self.__class__.__name__ - if self.download_feed(feed_file, self.url): - self.write_status() - # if self.verify(feed_file): - # LOG.info('GTFS verification succeeded.') - # return True - # else: - # LOG.error('GTFS verification failed.') - # return False - else: - return False - else: + if not self.url: raise ValueError("URL not set for feed source!") - def check_header_newer(self, feed_file: str, url: str): + feed_file = self.__class__.__name__ + if not self.download_feed(feed_file, self.url): + return False + + self.write_status() + + def check_header_newer(self, feed_file: str, url: str) -> Feed: """Check if last-modified header indicates a new download is available. :param feed_file: Name of downloaded file (relative to :ddir:) @@ -75,7 +69,7 @@ def check_header_newer(self, feed_file: str, url: str): """ if not os.path.exists(self.status_file): LOG.debug(f"Status file {self.status_file} not found.") - return 0 + return Feed.info_missing with open(self.status_file, "rb") as f: last_status = pickle.load(f) @@ -87,19 +81,19 @@ def check_header_newer(self, feed_file: str, url: str): last_mod = hdr.get("last-modified") if last_fetch >= last_mod: LOG.info(f"No new download available for {feed_file}.") - return -1 + return Feed.new_not_available else: LOG.info(f"New download available for {feed_file}.") LOG.info(f"Last download from: {last_fetch}.") LOG.info(f"New download posted: {last_mod}") - return 1 + return Feed.new_available else: # should try to find another way to check for new feeds if header not set LOG.debug(f"No last-modified header set for {feed_file} download link.") - return 0 + return Feed.info_missing else: LOG.debug(f"Time check entry for {feed_file} not found.") - return 0 + return Feed.info_missing def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> bool: """Download feed. @@ -109,7 +103,7 @@ def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> boo :param do_stream: If True, stream the download :returns: True if download was successful """ - if self.check_header_newer(feed_file, url) == -1: + if self.check_header_newer(feed_file, url) == Feed.new_not_available: # Nothing new to fetch; done here return False @@ -143,7 +137,7 @@ def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> boo self.set_error(feed_file, "Download failed") return False - def set_posted_date(self, feed_file: str, posted_date: str): + def set_posted_date(self, feed_file: str, posted_date: str) -> None: """Update feed status posted date. Creates new feed status if none found. :param feed_file: Name of feed file, relative to :ddir: @@ -153,7 +147,7 @@ def set_posted_date(self, feed_file: str, posted_date: str): stat["posted_date"] = posted_date self.status[feed_file] = stat - def set_error(self, feed_file: str, msg: str): + def set_error(self, feed_file: str, msg: str) -> None: """If error encountered in processing, set status error message, and unset other fields. :param feed_file: Name of feed file, relative to :ddir: diff --git a/gtfs/utils/check_params.py b/gtfs/utils/check_params.py new file mode 100644 index 0000000..9fcd7b1 --- /dev/null +++ b/gtfs/utils/check_params.py @@ -0,0 +1,58 @@ +import pathlib +from typing import Optional + +import typer + +from ..feed_sources import feed_sources +from ..utils.constants import LOG +from ..utils.geom import Bbox + + +def check_bbox(bbox: str) -> Optional[Bbox]: + if bbox is None: + return + try: + min_x, min_y, max_x, max_y = [float(coord) for coord in bbox.split(",")] + except ValueError as e: + err_message = e.args[0] + if "could not convert" in err_message: + raise typer.BadParameter("Please pass only numbers as bbox values!") + elif "not enough values to unpack" in err_message: + raise typer.BadParameter( + "Please pass bbox as a string separated by commas like this: min_x,min_y,max_x,max_y" + ) + else: + raise typer.BadParameter(f"Unhandled exception: {e}") + + if min_x == max_x or min_y == max_y: + raise typer.BadParameter("Area cannot be zero! Please pass a valid bbox.") + + return Bbox(min_x, min_y, max_x, max_y) + + +def check_sources(sources: str) -> Optional[str]: + """Check if the sources are valid.""" + if sources is None: + return + sources = sources.split(",") + for source in sources: + if not any(src.__name__.lower() == source.lower() for src in feed_sources): + raise typer.BadParameter(f"{source} is not a valid feed source!") + + return ",".join(sources) + + +def check_output_dir(output_dir: str) -> pathlib.Path: + """Check if the output directory is valid.""" + path = pathlib.Path.cwd() / "feeds" + + if output_dir != "feeds": + path = pathlib.Path.cwd() / output_dir[1:] if output_dir.startswith("/") else output_dir + + if pathlib.Path.exists(path): + LOG.info(f"Output directory {path} already exists.") + else: + LOG.info(f"Output directory {path} does not exist, will create it.") + pathlib.Path.mkdir(path, parents=True) + + return path diff --git a/gtfs/utils/constants.py b/gtfs/utils/constants.py index 9a1dbea..0ba5268 100644 --- a/gtfs/utils/constants.py +++ b/gtfs/utils/constants.py @@ -10,6 +10,12 @@ class Predicate(str, Enum): contains = "contains" +class Feed(str, Enum): + new_available = "new_available" + new_not_available = "new_not_available" + info_missing = "info_missing" + + def spinner(text: str, timer: int) -> None: with Progress( SpinnerColumn(), diff --git a/gtfs/utils/multithreading.py b/gtfs/utils/multithreading.py new file mode 100644 index 0000000..7a9b17c --- /dev/null +++ b/gtfs/utils/multithreading.py @@ -0,0 +1,38 @@ +import os +import pathlib +import threading + +from ..feed_source import FeedSource +from ..utils.constants import LOG + + +def multi_fetch(sources: list, output_dir_path: pathlib.Path, concurrency: int) -> None: + threads: list[threading.Thread] = [] + + def thread_worker(): + while True: + try: + src = sources.pop(0) + except IndexError: + break + + LOG.debug(f"Going to start fetch for {src}...") + try: + if issubclass(src, FeedSource): + inst = src() + inst.ddir = output_dir_path + inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") + inst.fetch() + else: + LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") + except AttributeError: + LOG.error(f"Skipping feed {src}, which could not be found.") + + for _ in range(concurrency): + thread = threading.Thread(target=thread_worker) + thread.start() + threads.append(thread) + + # Wait for all threads to complete + for thread in threads: + thread.join() diff --git a/gtfs/utils/validator.py b/gtfs/utils/validator.py deleted file mode 100644 index e69de29..0000000