Skip to content

Commit 68a791c

Browse files
committed
Add pathfinder tests.
1 parent 2c321b1 commit 68a791c

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

metricflow-semantics/tests_metricflow_semantics/experimental/mf_graph/flow_graph.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
)
2121
from metricflow_semantics.experimental.mf_graph.mutable_graph import MutableGraph
2222
from metricflow_semantics.experimental.mf_graph.node_descriptor import MetricflowGraphNodeDescriptor
23+
from metricflow_semantics.experimental.mf_graph.path_finding.graph_path import MutableGraphPath
24+
from metricflow_semantics.experimental.mf_graph.path_finding.pathfinder import MetricflowPathfinder
2325
from metricflow_semantics.experimental.ordered_set import MutableOrderedSet
2426
from metricflow_semantics.experimental.singleton import Singleton
2527
from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable
@@ -142,3 +144,7 @@ def as_sorted(self) -> FlowGraph:
142144
updated_graph.add_edge(edge)
143145

144146
return updated_graph
147+
148+
149+
FlowGraphPath = MutableGraphPath[FlowNode, FlowEdge]
150+
FlowGraphPathFinder = MetricflowPathfinder[FlowNode, FlowEdge, FlowGraphPath]

metricflow-semantics/tests_metricflow_semantics/experimental/mf_graph/path_finding/__init__.py

Whitespace-only changes.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
import pytest
6+
from _pytest.fixtures import FixtureRequest
7+
from metricflow_semantics.experimental.dataclass_helpers import fast_frozen_dataclass
8+
from metricflow_semantics.experimental.mf_graph.path_finding.graph_path import MutableGraphPath
9+
from metricflow_semantics.experimental.mf_graph.path_finding.pathfinder import MetricflowPathfinder
10+
from metricflow_semantics.experimental.mf_graph.path_finding.weight_function import EdgeCountWeightFunction
11+
from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet
12+
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
13+
from metricflow_semantics.test_helpers.snapshot_helpers import assert_object_snapshot_equal
14+
15+
from tests_metricflow_semantics.experimental.mf_graph.flow_graph import (
16+
FlowEdge,
17+
FlowGraph,
18+
FlowGraphPath,
19+
FlowGraphPathFinder,
20+
FlowNode,
21+
IntermediateNode,
22+
SinkNode,
23+
SourceNode,
24+
)
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
@fast_frozen_dataclass()
30+
class _PathFinderTestFixture:
31+
request: FixtureRequest
32+
snapshot_configuration: MetricFlowTestConfiguration
33+
graph: FlowGraph
34+
path: MutableGraphPath[FlowNode, FlowEdge]
35+
pathfinder: FlowGraphPathFinder
36+
source_node: FlowNode = SourceNode.get_instance(node_name="source")
37+
sink_node: FlowNode = SinkNode.get_instance(node_name="sink")
38+
a_node: FlowNode = IntermediateNode.get_instance(node_name="a")
39+
b_node: FlowNode = IntermediateNode.get_instance(node_name="b")
40+
41+
42+
@pytest.fixture
43+
def pathfinder_fixture( # noqa: D103
44+
request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, flow_graph: FlowGraph
45+
) -> _PathFinderTestFixture:
46+
return _PathFinderTestFixture(
47+
request=request,
48+
snapshot_configuration=mf_test_configuration,
49+
graph=flow_graph,
50+
pathfinder=MetricflowPathfinder(),
51+
path=MutableGraphPath.create(),
52+
)
53+
54+
55+
def test_find_paths_dfs(pathfinder_fixture: _PathFinderTestFixture) -> None: # noqa: D103
56+
max_weight_to_found_paths: dict[int, list[FlowGraphPath]] = {}
57+
58+
for max_path_weight in range(0, 5):
59+
found_paths: list[FlowGraphPath] = []
60+
for found_path in pathfinder_fixture.pathfinder.find_paths_dfs(
61+
graph=pathfinder_fixture.graph,
62+
initial_path=MutableGraphPath.create(pathfinder_fixture.source_node),
63+
target_nodes={pathfinder_fixture.sink_node},
64+
weight_function=EdgeCountWeightFunction(),
65+
max_path_weight=max_path_weight,
66+
):
67+
found_paths.append(found_path.copy())
68+
69+
max_weight_to_found_paths[max_path_weight] = found_paths.copy()
70+
71+
assert_object_snapshot_equal(
72+
request=pathfinder_fixture.request,
73+
snapshot_configuration=pathfinder_fixture.snapshot_configuration,
74+
obj={
75+
max_path_weight: sorted(found_paths) for max_path_weight, found_paths in max_weight_to_found_paths.items()
76+
},
77+
expectation_description="The dictionary shows the max. allowed path weight to the paths found.",
78+
)
79+
80+
81+
def test_find_ancestors(pathfinder_fixture: _PathFinderTestFixture) -> None: # noqa: D103
82+
find_ancestors_result = pathfinder_fixture.pathfinder.find_ancestors(
83+
graph=pathfinder_fixture.graph,
84+
source_nodes=FrozenOrderedSet((pathfinder_fixture.source_node,)),
85+
target_nodes=FrozenOrderedSet((pathfinder_fixture.sink_node,)),
86+
)
87+
assert_object_snapshot_equal(
88+
request=pathfinder_fixture.request,
89+
snapshot_configuration=pathfinder_fixture.snapshot_configuration,
90+
obj=find_ancestors_result,
91+
)
92+
93+
94+
def test_find_descendants(pathfinder_fixture: _PathFinderTestFixture) -> None: # noqa: D103
95+
find_descendants_result = pathfinder_fixture.pathfinder.find_descendants(
96+
graph=pathfinder_fixture.graph,
97+
source_nodes=FrozenOrderedSet((pathfinder_fixture.a_node, pathfinder_fixture.b_node)),
98+
target_nodes=FrozenOrderedSet((pathfinder_fixture.sink_node,)),
99+
)
100+
assert_object_snapshot_equal(
101+
request=pathfinder_fixture.request,
102+
snapshot_configuration=pathfinder_fixture.snapshot_configuration,
103+
obj=find_descendants_result,
104+
)

0 commit comments

Comments
 (0)