Skip to content

Commit 757dfc6

Browse files
authored
Add graph path / pathfinder classes (#1803)
This PR adds path / pathfinder classes to support basic graph traversal.
1 parent f00fbc3 commit 757dfc6

File tree

16 files changed

+980
-19
lines changed

16 files changed

+980
-19
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
kind: Under the Hood
2+
body: Add graph path / pathfinder classes
3+
time: 2025-08-01T14:11:59.829244-07:00
4+
custom:
5+
Author: plypaul
6+
Issue: "1803"
Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
from __future__ import annotations
22

33
from abc import ABC
4+
from functools import cached_property
45

6+
from typing_extensions import override
57

6-
class MetricflowGraphLabel(ABC):
8+
from metricflow_semantics.experimental.mf_graph.comparable import Comparable, ComparisonKey
9+
10+
11+
class MetricflowGraphLabel(Comparable, ABC):
    """Abstract base for label objects that are used to look up nodes / edges in a graph."""

    @cached_property
    @override
    def comparison_key(self) -> ComparisonKey:
        """Return the empty tuple as the comparison key for this label."""
        empty_key: ComparisonKey = ()
        return empty_key

metricflow-semantics/metricflow_semantics/experimental/mf_graph/path_finding/__init__.py

Whitespace-only changes.
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from abc import ABC, abstractmethod
5+
from collections.abc import Sequence, Set
6+
from dataclasses import dataclass
7+
from typing import Generic, Optional, Sized, TypeVar
8+
9+
from typing_extensions import Self, override
10+
11+
from metricflow_semantics.experimental.mf_graph.comparable import Comparable, ComparisonKey
12+
from metricflow_semantics.experimental.mf_graph.mf_graph import MetricflowGraphEdge, MetricflowGraphNode
13+
from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable
14+
from metricflow_semantics.mf_logging.pretty_formatter import PrettyFormatContext
15+
from metricflow_semantics.mf_logging.pretty_print import mf_pformat
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
# Type variable for the node type of a path; bound to `MetricflowGraphNode`.
NodeT = TypeVar("NodeT", bound=MetricflowGraphNode)
21+
# Type variable for the edge type of a path; bound to `MetricflowGraphEdge`.
EdgeT = TypeVar("EdgeT", bound=MetricflowGraphEdge)
22+
23+
24+
class MetricflowGraphPath(Generic[NodeT, EdgeT], Comparable, MetricFlowPrettyFormattable, Sized, ABC):
    """Read-only interface for a path through a directed graph."""

    @property
    @abstractmethod
    def edges(self) -> Sequence[EdgeT]:
        """The edges of this path, in order."""
        raise NotImplementedError

    @property
    @abstractmethod
    def nodes(self) -> Sequence[NodeT]:
        """The nodes of this path, in order."""
        raise NotImplementedError

    @property
    @abstractmethod
    def weight(self) -> int:
        """The path weight, as defined by the weight function that was used during traversal."""
        raise NotImplementedError

    @override
    def pretty_format(self, format_context: PrettyFormatContext) -> Optional[str]:
        """Format this path as its class name together with the node names and the weight."""
        formatter = format_context.formatter
        class_name = self.__class__.__name__
        node_names = [node.node_descriptor.node_name for node in self.nodes]
        return formatter.pretty_format_object_by_parts(
            class_name=class_name,
            field_mapping={"nodes": node_names, "weight": self.weight},
        )

    @property
    @abstractmethod
    def node_set(self) -> Set[NodeT]:
        """The nodes of this path as a set. Useful for fast cycle checks during path extension."""
        raise NotImplementedError

    @override
    def __len__(self) -> int:
        """Return the number of nodes in this path."""
        node_sequence = self.nodes
        return len(node_sequence)

    @abstractmethod
    def copy(self) -> Self:
        """Return a shallow copy of this path."""
        raise NotImplementedError

    def arrow_format(self) -> str:
        """Return a string representation that uses `->` between nodes.

        TODO: This is only used in tests so it should be migrated elsewhere.
        """
        weight_prefix = f"[weight: {self.weight}] "
        return weight_prefix + " -> ".join(mf_pformat(node) for node in self.nodes)
77+
78+
79+
@dataclass
class MutableGraphPath(MetricflowGraphPath[NodeT, EdgeT], Generic[NodeT, EdgeT]):
    """A path that can be extended and also reverted back to the previous state before extension.

    * This is a mutable class as path-finding can traverse many edges, and using a single mutable object reduces
      overhead.
    * The append / pop functionality is useful for DFS traversal.
    * When an edge is added to the path, the incremental weight added by the edge is specified by the caller (this
      does not do any weight calculation).
    """

    _nodes: list[NodeT]
    _edges: list[EdgeT]
    _current_weight: int
    _current_node_set: set[NodeT]

    # As this path is extended step-by-step, keep track of the weights added so that when `pop_end()` is called, the
    # weight of the path afterward can be easily computed by subtracting the last incremental weight added.
    _weight_addition_order: list[int]
    # One entry per element of `_nodes`: the node itself if appending it added a new member to
    # `_current_node_set`, or `None` if the node was already in the set. This lets `pop_end()` undo exactly the set
    # change made by the matching append.
    _node_set_addition_order: list[Optional[NodeT]]

    @staticmethod
    def create(start_node: Optional[NodeT] = None) -> MutableGraphPath:
        """Return a new path that is empty or, if `start_node` is given, consists of only that node."""
        path: MutableGraphPath[NodeT, EdgeT] = MutableGraphPath(
            _nodes=[],
            _edges=[],
            _weight_addition_order=[],
            _current_weight=0,
            _current_node_set=set(),
            _node_set_addition_order=[],
        )
        # Compare against `None` explicitly so that a node object that evaluates as falsy is still added.
        if start_node is not None:
            path._append_node(start_node)
        return path

    @property
    @override
    def edges(self) -> Sequence[EdgeT]:
        """The edges (in order) that constitute this path."""
        return self._edges

    @property
    @override
    def nodes(self) -> Sequence[NodeT]:
        """The nodes (in order) that constitute this path."""
        return self._nodes

    @property
    def is_empty(self) -> bool:
        """Whether this path contains no nodes."""
        return not self._nodes

    def _append_node(self, node: NodeT) -> None:
        """Helper to add a node to the path."""
        self._nodes.append(node)
        if node in self._current_node_set:
            # The node is already in the set, so record that this append did not change the set.
            self._node_set_addition_order.append(None)
        else:
            self._current_node_set.add(node)
            self._node_set_addition_order.append(node)

    def append_edge(self, edge: EdgeT, weight: int) -> None:
        """Add an edge with the given weight to this path.

        If the path has no nodes yet, the tail node of the edge is added before the head node.
        """
        tail_node = edge.tail_node
        head_node = edge.head_node
        if not self._nodes:
            self._append_node(tail_node)
        self._append_node(head_node)
        self._edges.append(edge)
        self._weight_addition_order.append(weight)
        self._current_weight += weight

    @property
    def weights(self) -> Sequence[int]:
        """The sequence of weights that were added to compute the total weight of the path."""
        return self._weight_addition_order

    @property
    @override
    def weight(self) -> int:
        """The current weight of the path."""
        return self._current_weight

    def pop_end(self) -> None:
        """Remove the last node / edge added to the path.

        If the path has at least one edge, the last edge, its weight, and the last node are removed. If the path
        has a node but no edges, only that node is removed.

        Raises:
            KeyError: If the path is empty.
        """
        if not self._nodes:
            raise KeyError("Can't pop an empty path")

        # Every entry in `_nodes` has a matching entry in `_node_set_addition_order`, so the node-set bookkeeping
        # must be undone on every pop. This includes the case where the only remaining node is removed - otherwise
        # `node_set` would still report a node that is no longer in the path.
        self._nodes.pop()
        added_node = self._node_set_addition_order.pop()
        if added_node is not None:
            self._current_node_set.remove(added_node)

        if self._edges:
            self._edges.pop()
            self._current_weight -= self._weight_addition_order.pop()

    @property
    @override
    def node_set(self) -> Set[NodeT]:
        """A set containing the nodes in this path."""
        return self._current_node_set

    @override
    def copy(self) -> Self:
        """Return a shallow copy: the internal containers are copied, but the nodes and edges are shared."""
        # noinspection PyArgumentList
        return self.__class__(
            _nodes=self._nodes.copy(),
            _edges=self._edges.copy(),
            _current_weight=self._current_weight,
            _current_node_set=self._current_node_set.copy(),
            _weight_addition_order=self._weight_addition_order.copy(),
            _node_set_addition_order=self._node_set_addition_order.copy(),
        )

    @property
    @override
    def comparison_key(self) -> ComparisonKey:
        """The key built from the current nodes, weight, and edges.

        This is a plain property (not cached) so that the key reflects the path as it is mutated.
        """
        return (
            tuple(self._nodes),
            self._current_weight,
            tuple(self._edges),
        )
200+
201+
202+
# Type variable for path types; bound to `MutableGraphPath`.
MutablePathT = TypeVar("MutablePathT", bound=MutableGraphPath)

0 commit comments

Comments
 (0)