|
1 | 1 | from __future__ import annotations |
2 | 2 | import enum |
| 3 | +from dataclasses import dataclass |
3 | 4 | from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING |
4 | 5 | from pathlib import Path |
5 | 6 |
|
|
19 | 20 | if TYPE_CHECKING: |
20 | 21 | import networkx |
21 | 22 | import pandas as pd |
| 23 | + import geopandas |
22 | 24 |
|
23 | 25 | log = logging.getLogger(__name__) |
24 | 26 |
|
|
32 | 34 | AggFunc = Callable[[Any], Any] |
33 | 35 |
|
34 | 36 |
|
| 37 | +@dataclass |
| 38 | +class HookParameters: |
| 39 | + """ |
| 40 | + Parameters passed to hooks registered with generate_compass_dataset. |
| 41 | +
|
| 42 | + These parameters allow developers to access and modify the road network |
| 43 | + data before the dataset generation process completes. |
| 44 | +
|
| 45 | + Attributes: |
| 46 | + output_directory (Path): The directory where the dataset files are being written. |
| 47 | + vertices (pd.DataFrame): A DataFrame containing the vertex (node) data |
| 48 | + of the road network, including coordinates and IDs. |
| 49 | + edges (geopandas.GeoDataFrame): A GeoDataFrame containing the edge (link) |
| 50 | + data, including geometries and attributes like distance and speed. |
| 51 | + graph (networkx.MultiDiGraph): The processed NetworkX graph object |
| 52 | + representing the road network topology and attributes. |
| 53 | + """ |
| 54 | + |
| 55 | + output_directory: Path |
| 56 | + vertices: pd.DataFrame |
| 57 | + edges: geopandas.GeoDataFrame |
| 58 | + graph: networkx.MultiDiGraph |
| 59 | + |
| 60 | + |
| 61 | +DatasetHook = Callable[[HookParameters], None] |
| 62 | + |
| 63 | + |
35 | 64 | class GeneratePipelinePhase(enum.Enum): |
36 | 65 | GRAPH = 1 |
37 | 66 | CONFIG = 2 |
@@ -78,6 +107,7 @@ def generate_compass_dataset( |
78 | 107 | requests_kwds: Optional[Dict[Any, Any]] = None, |
79 | 108 | afdc_api_key: str = "DEMO_KEY", |
80 | 109 | vehicle_models: Optional[List[str]] = None, |
| 110 | + hooks: Optional[List[DatasetHook]] = None, |
81 | 111 | ) -> None: |
82 | 112 | """ |
83 | 113 | Processes a graph downloaded via OSMNx, generating the set of input |
@@ -109,6 +139,9 @@ def generate_compass_dataset( |
109 | 139 | ``["2017_CHEVROLET_Bolt", "2016_TOYOTA_Camry_4cyl_2WD"]``). |
110 | 140 | Use :func:`list_available_vehicle_models` to see valid names. |
111 | 141 | When ``None`` (the default) all available models are included. |
| 142 | + hooks: Optional list of callables that take a ``HookParameters`` object. |
| 143 | + These hooks will be called after the dataset has been generated |
| 144 | + and before the function returns. |
112 | 145 | Example: |
113 | 146 | >>> import osmnx as ox |
114 | 147 | >>> g = ox.graph_from_place("Denver, Colorado, USA") |
@@ -364,6 +397,18 @@ def replace_id(vertex_uuid: pd.Index) -> pd.Series[int]: |
364 | 397 | compression="gzip", |
365 | 398 | ) |
366 | 399 |
|
| 400 | + # RUN HOOKS |
| 401 | + if hooks is not None: |
| 402 | + log.info(f"running {len(hooks)} dataset generation hooks") |
| 403 | + params = HookParameters( |
| 404 | + output_directory=output_directory, |
| 405 | + vertices=v, |
| 406 | + edges=e, |
| 407 | + graph=g1, |
| 408 | + ) |
| 409 | + for hook in hooks: |
| 410 | + hook(params) |
| 411 | + |
367 | 412 |
|
368 | 413 | def _resolve_required_model_bins(vehicle_models: List[str]) -> set[str]: |
369 | 414 | """ |
|
0 commit comments