Skip to content

Commit 80be43a

Browse files
authored
Add hooks to generate_compass_dataset function (#469)
* guard examples with main function; update conversion script * update examples based on latest changes * fix search pruning taking out parents * add hooks to generate_compass_dataset function * make hook parameter field names more verbose and add better doc string
1 parent 9fd966a commit 80be43a

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

python/nrel/routee/compass/compass_app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from nrel.routee.compass.io.generate_dataset import (
1212
GeneratePipelinePhase,
1313
generate_compass_dataset,
14+
DatasetHook,
1415
)
1516

1617
if TYPE_CHECKING:
@@ -136,6 +137,7 @@ def from_graph(
136137
vehicle_models: Optional[List[str]] = None,
137138
parallelism: Optional[int] = None,
138139
overwrite: bool = False,
140+
hooks: Optional[List[DatasetHook]] = None,
139141
) -> CompassApp:
140142
"""
141143
Build a CompassApp from a networkx graph.
@@ -207,6 +209,7 @@ def from_graph(
207209
raster_resolution_arc_seconds=raster_resolution_arc_seconds,
208210
default_config=True,
209211
vehicle_models=vehicle_models,
212+
hooks=hooks,
210213
)
211214

212215
if not config_path.exists():

python/nrel/routee/compass/io/generate_dataset.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
import enum
3+
from dataclasses import dataclass
34
from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING
45
from pathlib import Path
56

@@ -19,6 +20,7 @@
1920
if TYPE_CHECKING:
2021
import networkx
2122
import pandas as pd
23+
import geopandas
2224

2325
log = logging.getLogger(__name__)
2426

@@ -32,6 +34,33 @@
3234
AggFunc = Callable[[Any], Any]
3335

3436

37+
@dataclass
38+
class HookParameters:
39+
"""
40+
Parameters passed to hooks registered with generate_compass_dataset.
41+
42+
These parameters allow developers to access and modify the road network
43+
data before the dataset generation process completes.
44+
45+
Attributes:
46+
output_directory (Path): The directory where the dataset files are being written.
47+
vertices (pd.DataFrame): A DataFrame containing the vertex (node) data
48+
of the road network, including coordinates and IDs.
49+
edges (geopandas.GeoDataFrame): A GeoDataFrame containing the edge (link)
50+
data, including geometries and attributes like distance and speed.
51+
graph (networkx.MultiDiGraph): The processed NetworkX graph object
52+
representing the road network topology and attributes.
53+
"""
54+
55+
output_directory: Path
56+
vertices: pd.DataFrame
57+
edges: geopandas.GeoDataFrame
58+
graph: networkx.MultiDiGraph
59+
60+
61+
DatasetHook = Callable[[HookParameters], None]
62+
63+
3564
class GeneratePipelinePhase(enum.Enum):
3665
GRAPH = 1
3766
CONFIG = 2
@@ -78,6 +107,7 @@ def generate_compass_dataset(
78107
requests_kwds: Optional[Dict[Any, Any]] = None,
79108
afdc_api_key: str = "DEMO_KEY",
80109
vehicle_models: Optional[List[str]] = None,
110+
hooks: Optional[List[DatasetHook]] = None,
81111
) -> None:
82112
"""
83113
Processes a graph downloaded via OSMNx, generating the set of input
@@ -109,6 +139,9 @@ def generate_compass_dataset(
109139
``["2017_CHEVROLET_Bolt", "2016_TOYOTA_Camry_4cyl_2WD"]``).
110140
Use :func:`list_available_vehicle_models` to see valid names.
111141
When ``None`` (the default) all available models are included.
142+
hooks: Optional list of callables that take a ``HookParameters`` object.
143+
These hooks will be called after the dataset has been generated
144+
and before the function returns.
112145
Example:
113146
>>> import osmnx as ox
114147
>>> g = ox.graph_from_place("Denver, Colorado, USA")
@@ -364,6 +397,18 @@ def replace_id(vertex_uuid: pd.Index) -> pd.Series[int]:
364397
compression="gzip",
365398
)
366399

400+
# RUN HOOKS
401+
if hooks is not None:
402+
log.info(f"running {len(hooks)} dataset generation hooks")
403+
params = HookParameters(
404+
output_directory=output_directory,
405+
vertices=v,
406+
edges=e,
407+
graph=g1,
408+
)
409+
for hook in hooks:
410+
hook(params)
411+
367412

368413
def _resolve_required_model_bins(vehicle_models: List[str]) -> set[str]:
369414
"""

0 commit comments

Comments
 (0)