diff --git a/.gitignore b/.gitignore index b627ff1..b0441f3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,9 @@ .idea gtfs/__version__.py + # Generated / downloaded files *.zip -*.p -*.csv -*.html -patco-gtfs/ -transitfeedcrash.txt +*.pkl # virtualenv .venv/ @@ -18,51 +15,3 @@ transitfeedcrash.txt # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] - -# C extensions -*.so - -# Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -*.egg-info/ -.installed.cfg -*.egg - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*,cover - -# Translations -*.mo -*.pot - -# PyBuilder -target/ diff --git a/gtfs/__main__.py b/gtfs/__main__.py index e1465ad..78bcb4e 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -1,42 +1,18 @@ #!/usr/bin/env python """Command line interface for fetching GTFS.""" -import logging from typing import Optional import typer from prettytable.colortable import ColorTable, Themes from typing_extensions import Annotated -from .feed_source import FeedSource from .feed_sources import feed_sources -from .utils.constants import Predicate, spinner +from .utils.check_params import check_bbox, check_output_dir, check_sources +from .utils.constants import LOG, Predicate, spinner from .utils.geom import Bbox, bbox_contains_bbox, bbox_intersects_bbox +from .utils.multithreading import multi_fetch -logging.basicConfig() -LOG = logging.getLogger() -app = typer.Typer() - - -def check_bbox(bbox: str) -> Optional[Bbox]: - if bbox is None: - return - try: - min_x, min_y, max_x, max_y = [float(coord) for coord in bbox.split(",")] - except ValueError as e: - 
err_message = e.args[0] - if "could not convert" in err_message: - raise typer.BadParameter("Please pass only numbers as bbox values!") - elif "not enough values to unpack" in err_message: - raise typer.BadParameter( - "Please pass bbox as a string separated by commas like this: min_x,min_y,max_x,max_y" - ) - else: - raise typer.BadParameter(f"Unhandled exception: {e}") - - if min_x == max_x or min_y == max_y: - raise typer.BadParameter("Area cannot be zero! Please pass a valid bbox.") - - return Bbox(min_x, min_y, max_x, max_y) +app = typer.Typer(help="Fetch GTFS feeds from various transit agencies.") @app.command() @@ -58,6 +34,14 @@ def list_feeds( help="the gtfs feed should intersect or should be contained inside the user's bbox", ), ] = None, + search: Annotated[ + Optional[str], + typer.Option( + "--search", + "-s", + help="search for feeds based on a string", + ), + ] = None, pretty: Annotated[ bool, typer.Option( @@ -67,109 +51,112 @@ def list_feeds( ), ] = False, ) -> None: - """Filter feeds spatially based on bounding box.""" + """Filter feeds spatially based on bounding box or search string. + + :param bbox: set of coordinates to filter feeds spatially + :param predicate: the gtfs feed should intersect or should be contained inside the user's bbox + :param search: Search for feeds based on a string. + :param pretty: display feeds inside a pretty table + """ + sources: list = feed_sources + + if search is not None: + if bbox is not None or predicate is not None: + raise typer.BadParameter( + "Please pass either bbox or search text, not both at the same time!" + ) + else: + sources = [src for src in feed_sources if search.lower() in src.__name__.lower()] + if bbox is None and predicate is not None: raise typer.BadParameter( f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!" 
) - elif bbox is not None and predicate is None: + + if bbox is not None and predicate is None: raise typer.BadParameter( f"Please pass a predicate if you want to filter feeds spatially based on bbox = {bbox}!" ) - else: - spinner("Fetching feeds...", 1) - if pretty is True: - pretty_output = ColorTable( - ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 - ) - for src in feed_sources: - feed_bbox: Bbox = src.bbox - if bbox is not None and predicate == "contains": - if not bbox_contains_bbox(feed_bbox, bbox): - continue - elif bbox is not None and predicate == "intersects": - if (not bbox_intersects_bbox(feed_bbox, bbox)) and ( - not bbox_intersects_bbox(bbox, feed_bbox) - ): - continue - - if pretty is True: - pretty_output.add_row( - [ - src.__name__, - src.url, - [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], - ] - ) - continue + spinner("Filtering feeds...", 1) - print(src.url) + if pretty is True: + pretty_output = ColorTable( + ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 + ) + + for src in sources: + feed_bbox: Bbox = src.bbox + if bbox is not None and predicate == "contains": + if not bbox_contains_bbox(feed_bbox, bbox): + continue + elif bbox is not None and predicate == "intersects": + if (not bbox_intersects_bbox(feed_bbox, bbox)) and ( + not bbox_intersects_bbox(bbox, feed_bbox) + ): + continue if pretty is True: - print("\n" + pretty_output.get_string()) + pretty_output.add_row( + [ + src.__name__, + src.url, + [feed_bbox.min_x, feed_bbox.min_y, feed_bbox.max_x, feed_bbox.max_y], + ] + ) + continue + + print(src.url) + + if pretty is True: + print("\n" + pretty_output.get_string()) @app.command() -def fetch_feeds(sources=None): - """ +def fetch_feeds( + sources: Annotated[ + Optional[str], + typer.Option( + "--sources", + "-src", + help="pass value as a string separated by commas like this: Berlin,AlbanyNy,...", + callback=check_sources, + ), + ] = None, + output_dir: 
Annotated[ + str, + typer.Option( + "--output-dir", + "-o", + help="the directory where the downloaded feeds will be saved, default is feeds", + callback=check_output_dir, + ), + ] = "feeds", + concurrency: Annotated[ + Optional[int], + typer.Option( + "--concurrency", + "-c", + help="the number of concurrent downloads, default is 4", + ), + ] = 4, +) -> None: + """Fetch feeds from sources. + :param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available. + :param output_dir: The directory where the downloaded feeds will be saved; default is "feeds" + in current working directory. + :param concurrency: The number of concurrent downloads; default is 4. """ - statuses = {} # collect the statuses for all the files - # default to use all of them if not sources: sources = feed_sources + else: + sources = [src for src in feed_sources if src.__name__.lower() in sources.lower().split(",")] - LOG.info("Going to fetch feeds from sources: %s", sources) - for src in sources: - LOG.debug("Going to start fetch for %s...", src) - try: - if issubclass(src, FeedSource): - inst = src() - inst.fetch() - statuses.update(inst.status) - else: - LOG.warning( - "Skipping class %s, which does not subclass FeedSource.", - src.__name__, - ) - except AttributeError: - LOG.error("Skipping feed %s, which could not be found.", src) - - # remove last check key set at top level of each status dictionary - if "last_check" in statuses: - del statuses["last_check"] - - ptable = ColorTable( - [ - "file", - "new?", - "valid?", - "current?", - "newly effective?", - "error", - ], - theme=Themes.OCEAN, - hrules=1, - ) - - for file_name in statuses: - stat = statuses[file_name] - msg = [] - msg.append(file_name) - msg.append("x" if "is_new" in stat and stat["is_new"] else "") - msg.append("x" if "is_valid" in stat and stat["is_valid"] else "") - msg.append("x" if "is_current" in stat and stat["is_current"] else "") - msg.append("x" if "newly_effective" in stat and
stat.get("newly_effective") else "") - if "error" in stat: - msg.append(stat["error"]) - else: - msg.append("") - ptable.add_row(msg) + LOG.info(f"Going to fetch feeds from sources: {sources}") - LOG.info("Results:\n%s", ptable.get_string()) - LOG.info("All done!") + multi_fetch(sources, output_dir, concurrency) if __name__ == "__main__": diff --git a/gtfs/feed_source.py b/gtfs/feed_source.py index 04926c3..8e214ed 100644 --- a/gtfs/feed_source.py +++ b/gtfs/feed_source.py @@ -2,8 +2,15 @@ To add a new feed, add a subclass of this to the `feed_sources` directory. """ +import os +import pickle +import zipfile from abc import ABC, abstractmethod +from datetime import datetime +import requests + +from gtfs.utils.constants import LOG, TIMECHECK_FMT, Feed from gtfs.utils.geom import Bbox @@ -16,6 +23,11 @@ class FeedSource(ABC): - override :fetch: method as necessary to fetch feeds for the agency. """ + def __init__(self): + self.status = {} + self.ddir = "" + self.status_file = "" + @property @abstractmethod def url(self) -> str: @@ -26,10 +38,121 @@ def url(self) -> str: def bbox(self) -> Bbox: pass - def fetch(self): - """ - Modify this method in subclass for importing feed(s) from agency. + def write_status(self) -> None: + """Write pickled log of feed statuses and last times files were downloaded.""" + LOG.debug(f"Downloading finished. Writing status file {self.status_file}...") + with open(self.status_file, "wb") as status_file: + pickle.dump(self.status, status_file) + LOG.debug(f"Statuses written to {self.status_file}.") - By default, loops over given URLs, checks the last-modified header to see if a new + def fetch(self) -> bool: + """Modify this method in subclass for importing feed(s) from agency. + + By default, checks the last-modified header to see if a new download is available, streams the download if so, and verifies the new GTFS. 
""" + if not self.url: + raise ValueError("URL not set for feed source!") + + feed_file = self.__class__.__name__ + if not self.download_feed(feed_file, self.url): + return False + + self.write_status() + + def check_header_newer(self, feed_file: str, url: str) -> Feed: + """Check if last-modified header indicates a new download is available. + + :param feed_file: Name of downloaded file (relative to :ddir:) + :param url: Where GTFS is downloaded from + :returns: 1 if newer GTFS available; 0 if info missing; -1 if already have most recent + """ + if not os.path.exists(self.status_file): + LOG.debug(f"Status file {self.status_file} not found.") + return Feed.info_missing + + with open(self.status_file, "rb") as f: + last_status = pickle.load(f) + if feed_file in last_status and "posted_date" in last_status[feed_file]: + last_fetch = last_status[feed_file]["posted_date"] + hdr = requests.head(url) + hdr = hdr.headers + if hdr.get("last-modified"): + last_mod = hdr.get("last-modified") + if last_fetch >= last_mod: + LOG.info(f"No new download available for {feed_file}.") + return Feed.new_not_available + else: + LOG.info(f"New download available for {feed_file}.") + LOG.info(f"Last download from: {last_fetch}.") + LOG.info(f"New download posted: {last_mod}") + return Feed.new_available + else: + # should try to find another way to check for new feeds if header not set + LOG.debug(f"No last-modified header set for {feed_file} download link.") + return Feed.info_missing + else: + LOG.debug(f"Time check entry for {feed_file} not found.") + return Feed.info_missing + + def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> bool: + """Download feed. 
+ + :param feed_file: File name to save download as, relative to :ddir: + :param url: Where to download the GTFS from + :param do_stream: If True, stream the download + :returns: True if download was successful + """ + if self.check_header_newer(feed_file, url) == Feed.new_not_available: + # Nothing new to fetch; done here + return False + + feed_file_path = os.path.join(self.ddir, feed_file + ".zip") + LOG.info(f"Getting file {feed_file}...from...{url}") + request = requests.get(url, stream=do_stream) + + if request.ok: + with open(feed_file_path, "wb") as download_file: + if do_stream: + for chunk in request.iter_content(chunk_size=1024): + download_file.write(chunk) + else: + download_file.write(request.content) + + info = os.stat(feed_file_path) + if info.st_size < 10000: + # file smaller than 10K; may not be a GTFS + LOG.warning(f"Download for {feed_file_path} is only {str(info.st_size)} bytes.") + if not zipfile.is_zipfile(feed_file_path): + self.set_error(feed_file, "Download is not a zip file") + return False + posted_date = request.headers.get("last-modified") + if not posted_date: + LOG.debug("No last-modified header set") + posted_date = datetime.utcnow().strftime(TIMECHECK_FMT) + self.set_posted_date(feed_file, posted_date) + LOG.info(f"Download completed successfully for {feed_file}.") + return True + else: + self.set_error(feed_file, "Download failed") + return False + + def set_posted_date(self, feed_file: str, posted_date: str) -> None: + """Update feed status posted date. Creates new feed status if none found. + + :param feed_file: Name of feed file, relative to :ddir: + :param posted_date: Date string formatted to :TIMECHECK_FMT: when feed was posted + """ + stat = self.status.get(feed_file, {}) + stat["posted_date"] = posted_date + self.status[feed_file] = stat + + def set_error(self, feed_file: str, msg: str) -> None: + """If error encountered in processing, set status error message, and unset other fields. 
+ + :param feed_file: Name of feed file, relative to :ddir: + :param msg: Error message to save with status + """ + LOG.error(f"Error processing {feed_file}: {msg}") + self.status[feed_file] = {"error": msg} + self.write_status() diff --git a/gtfs/feed_sources/AlbanyNy.py b/gtfs/feed_sources/AlbanyNy.py index 0df3681..37ce15a 100644 --- a/gtfs/feed_sources/AlbanyNy.py +++ b/gtfs/feed_sources/AlbanyNy.py @@ -8,3 +8,6 @@ class AlbanyNy(FeedSource): url: str = "http://www.cdta.org/schedules/google_transit.zip" bbox: Bbox = Bbox(-74.219321, 42.467161, -73.614608, 43.10706) + + def __init__(self): + super().__init__() diff --git a/gtfs/feed_sources/Berlin.py b/gtfs/feed_sources/Berlin.py index 7e6c63a..7a703d2 100644 --- a/gtfs/feed_sources/Berlin.py +++ b/gtfs/feed_sources/Berlin.py @@ -10,3 +10,6 @@ class Berlin(FeedSource): "https://www.vbb.de/fileadmin/user_upload/VBB/Dokumente/API-Datensaetze/gtfs-mastscharf/GTFS.zip" ) bbox: Bbox = Bbox(10.669821, 50.839245, 17.037088, 54.308626) + + def __init__(self): + super().__init__() diff --git a/gtfs/utils/check_params.py b/gtfs/utils/check_params.py new file mode 100644 index 0000000..9fcd7b1 --- /dev/null +++ b/gtfs/utils/check_params.py @@ -0,0 +1,58 @@ +import pathlib +from typing import Optional + +import typer + +from ..feed_sources import feed_sources +from ..utils.constants import LOG +from ..utils.geom import Bbox + + +def check_bbox(bbox: str) -> Optional[Bbox]: + if bbox is None: + return + try: + min_x, min_y, max_x, max_y = [float(coord) for coord in bbox.split(",")] + except ValueError as e: + err_message = e.args[0] + if "could not convert" in err_message: + raise typer.BadParameter("Please pass only numbers as bbox values!") + elif "not enough values to unpack" in err_message: + raise typer.BadParameter( + "Please pass bbox as a string separated by commas like this: min_x,min_y,max_x,max_y" + ) + else: + raise typer.BadParameter(f"Unhandled exception: {e}") + + if min_x == max_x or min_y == max_y: + raise 
typer.BadParameter("Area cannot be zero! Please pass a valid bbox.") + + return Bbox(min_x, min_y, max_x, max_y) + + +def check_sources(sources: str) -> Optional[str]: + """Check if the sources are valid.""" + if sources is None: + return + sources = sources.split(",") + for source in sources: + if not any(src.__name__.lower() == source.lower() for src in feed_sources): + raise typer.BadParameter(f"{source} is not a valid feed source!") + + return ",".join(sources) + + +def check_output_dir(output_dir: str) -> pathlib.Path: + """Check if the output directory is valid.""" + path = pathlib.Path.cwd() / "feeds" + + if output_dir != "feeds": + path = pathlib.Path.cwd() / output_dir[1:] if output_dir.startswith("/") else output_dir + + if pathlib.Path.exists(path): + LOG.info(f"Output directory {path} already exists.") + else: + LOG.info(f"Output directory {path} does not exist, will create it.") + pathlib.Path.mkdir(path, parents=True) + + return path diff --git a/gtfs/utils/constants.py b/gtfs/utils/constants.py index e4e62e0..0ba5268 100644 --- a/gtfs/utils/constants.py +++ b/gtfs/utils/constants.py @@ -1,3 +1,4 @@ +import logging import time from enum import Enum @@ -9,6 +10,12 @@ class Predicate(str, Enum): contains = "contains" +class Feed(str, Enum): + new_available = "new_available" + new_not_available = "new_not_available" + info_missing = "info_missing" + + def spinner(text: str, timer: int) -> None: with Progress( SpinnerColumn(), @@ -17,3 +24,10 @@ def spinner(text: str, timer: int) -> None: ) as progress: progress.add_task(description=text, total=None) time.sleep(timer) + + +logging.basicConfig(level=logging.INFO) +LOG = logging.getLogger() + +# format time checks like last-modified header +TIMECHECK_FMT = "%a, %d %b %Y %H:%M:%S GMT" diff --git a/gtfs/utils/multithreading.py b/gtfs/utils/multithreading.py new file mode 100644 index 0000000..7a9b17c --- /dev/null +++ b/gtfs/utils/multithreading.py @@ -0,0 +1,38 @@ +import os +import pathlib +import threading 
+ +from ..feed_source import FeedSource +from ..utils.constants import LOG + + +def multi_fetch(sources: list, output_dir_path: pathlib.Path, concurrency: int) -> None: + threads: list[threading.Thread] = [] + + def thread_worker(): + while True: + try: + src = sources.pop(0) + except IndexError: + break + + LOG.debug(f"Going to start fetch for {src}...") + try: + if issubclass(src, FeedSource): + inst = src() + inst.ddir = output_dir_path + inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") + inst.fetch() + else: + LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") + except AttributeError: + LOG.error(f"Skipping feed {src}, which could not be found.") + + for _ in range(concurrency): + thread = threading.Thread(target=thread_worker) + thread.start() + threads.append(thread) + + # Wait for all threads to complete + for thread in threads: + thread.join() diff --git a/poetry.lock b/poetry.lock index dda3005..7fa424a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "certifi" diff --git a/tests/test_cli.py b/tests/test_cli.py index eca1b13..7438cfd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,7 +13,7 @@ class TestListFeedsCommand: def test_help(self, runner): result = runner.invoke(app, ["list-feeds", "--help"]) assert result.exit_code == 0 - assert "Filter feeds spatially based on bounding box." in result.stdout + assert "Filter feeds spatially based on bounding box or search string." in result.stdout def test_bad_args_1(self, runner): result = runner.invoke(app, ["list-feeds", "--bbox", "6.626953,49.423342,23.348144"]) @@ -30,18 +30,36 @@ def test_bad_args_3(self, runner): assert result.exit_code == 2 assert "Area cannot be zero!" 
in result.stdout + def test_bad_args_4(self, runner): + result = runner.invoke( + app, ["list-feeds", "--bbox", "6.626953,49.423342,23.348144,54.265953", "--search", "cdta"] + ) + assert result.exit_code == 2 + assert "Please pass either bbox or search" in result.stdout + def test_intersects_predicate(self, runner): result = runner.invoke( - app, ["list-feeds", "-pd", "intersects", "--bbox", "6.626953,49.423342,23.348144,54.265953"] + app, + ["list-feeds", "-pd", "intersects", "--bbox", "6.626953,49.423342,23.348144,54.265953"], + input="N\n", ) assert result.exit_code == 0 def test_contains_predicate(self, runner): result = runner.invoke( - app, ["list-feeds", "-pd", "contains", "--bbox", "6.626953,49.423342,23.348144,54.265953"] + app, + ["list-feeds", "-pd", "contains", "--bbox", "6.626953,49.423342,23.348144,54.265953"], + input="N\n", ) assert result.exit_code == 0 def test_pretty(self, runner): - result = runner.invoke(app, ["list-feeds", "-pt"]) + result = runner.invoke(app, ["list-feeds", "-pt"], input="N\n") + assert result.exit_code == 0 + + +class TestFetchFeedsCommand: + def test_help(self, runner): + result = runner.invoke(app, ["fetch-feeds", "--help"]) assert result.exit_code == 0 + assert "Fetch feeds from sources." in result.stdout