|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | + |
| 5 | +import pytest |
| 6 | +from _pytest.fixtures import FixtureRequest |
| 7 | +from metricflow_semantics.experimental.dataclass_helpers import fast_frozen_dataclass |
| 8 | +from metricflow_semantics.experimental.mf_graph.path_finding.graph_path import MutableGraphPath |
| 9 | +from metricflow_semantics.experimental.mf_graph.path_finding.pathfinder import MetricflowPathfinder |
| 10 | +from metricflow_semantics.experimental.mf_graph.path_finding.weight_function import EdgeCountWeightFunction |
| 11 | +from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet |
| 12 | +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration |
| 13 | +from metricflow_semantics.test_helpers.snapshot_helpers import assert_object_snapshot_equal |
| 14 | + |
| 15 | +from tests_metricflow_semantics.experimental.mf_graph.flow_graph import ( |
| 16 | + FlowEdge, |
| 17 | + FlowGraph, |
| 18 | + FlowGraphPath, |
| 19 | + FlowGraphPathFinder, |
| 20 | + FlowNode, |
| 21 | + IntermediateNode, |
| 22 | + SinkNode, |
| 23 | + SourceNode, |
| 24 | +) |
| 25 | + |
| 26 | +logger = logging.getLogger(__name__) |
| 27 | + |
| 28 | + |
| 29 | +@fast_frozen_dataclass() |
| 30 | +class _PathFinderTestFixture: |
| 31 | + request: FixtureRequest |
| 32 | + snapshot_configuration: MetricFlowTestConfiguration |
| 33 | + graph: FlowGraph |
| 34 | + path: MutableGraphPath[FlowNode, FlowEdge] |
| 35 | + pathfinder: FlowGraphPathFinder |
| 36 | + source_node: FlowNode = SourceNode.get_instance(node_name="source") |
| 37 | + sink_node: FlowNode = SinkNode.get_instance(node_name="sink") |
| 38 | + a_node: FlowNode = IntermediateNode.get_instance(node_name="a") |
| 39 | + b_node: FlowNode = IntermediateNode.get_instance(node_name="b") |
| 40 | + |
| 41 | + |
| 42 | +@pytest.fixture |
| 43 | +def pathfinder_fixture( # noqa: D103 |
| 44 | + request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, flow_graph: FlowGraph |
| 45 | +) -> _PathFinderTestFixture: |
| 46 | + return _PathFinderTestFixture( |
| 47 | + request=request, |
| 48 | + snapshot_configuration=mf_test_configuration, |
| 49 | + graph=flow_graph, |
| 50 | + pathfinder=MetricflowPathfinder(), |
| 51 | + path=MutableGraphPath.create(), |
| 52 | + ) |
| 53 | + |
| 54 | + |
| 55 | +def test_find_paths_dfs(pathfinder_fixture: _PathFinderTestFixture) -> None: # noqa: D103 |
| 56 | + max_weight_to_found_paths: dict[int, list[FlowGraphPath]] = {} |
| 57 | + |
| 58 | + for max_path_weight in range(0, 5): |
| 59 | + found_paths: list[FlowGraphPath] = [] |
| 60 | + for found_path in pathfinder_fixture.pathfinder.find_paths_dfs( |
| 61 | + graph=pathfinder_fixture.graph, |
| 62 | + initial_path=MutableGraphPath.create(pathfinder_fixture.source_node), |
| 63 | + target_nodes={pathfinder_fixture.sink_node}, |
| 64 | + weight_function=EdgeCountWeightFunction(), |
| 65 | + max_path_weight=max_path_weight, |
| 66 | + ): |
| 67 | + found_paths.append(found_path.copy()) |
| 68 | + |
| 69 | + max_weight_to_found_paths[max_path_weight] = found_paths.copy() |
| 70 | + |
| 71 | + assert_object_snapshot_equal( |
| 72 | + request=pathfinder_fixture.request, |
| 73 | + snapshot_configuration=pathfinder_fixture.snapshot_configuration, |
| 74 | + obj={ |
| 75 | + max_path_weight: sorted(found_paths) for max_path_weight, found_paths in max_weight_to_found_paths.items() |
| 76 | + }, |
| 77 | + expectation_description="The dictionary shows the max. allowed path weight to the paths found.", |
| 78 | + ) |
| 79 | + |
| 80 | + |
| 81 | +def test_find_ancestors(pathfinder_fixture: _PathFinderTestFixture) -> None: # noqa: D103 |
| 82 | + find_ancestors_result = pathfinder_fixture.pathfinder.find_ancestors( |
| 83 | + graph=pathfinder_fixture.graph, |
| 84 | + source_nodes=FrozenOrderedSet((pathfinder_fixture.source_node,)), |
| 85 | + target_nodes=FrozenOrderedSet((pathfinder_fixture.sink_node,)), |
| 86 | + ) |
| 87 | + assert_object_snapshot_equal( |
| 88 | + request=pathfinder_fixture.request, |
| 89 | + snapshot_configuration=pathfinder_fixture.snapshot_configuration, |
| 90 | + obj=find_ancestors_result, |
| 91 | + ) |
| 92 | + |
| 93 | + |
| 94 | +def test_find_descendants(pathfinder_fixture: _PathFinderTestFixture) -> None: # noqa: D103 |
| 95 | + find_descendants_result = pathfinder_fixture.pathfinder.find_descendants( |
| 96 | + graph=pathfinder_fixture.graph, |
| 97 | + source_nodes=FrozenOrderedSet((pathfinder_fixture.a_node, pathfinder_fixture.b_node)), |
| 98 | + target_nodes=FrozenOrderedSet((pathfinder_fixture.sink_node,)), |
| 99 | + ) |
| 100 | + assert_object_snapshot_equal( |
| 101 | + request=pathfinder_fixture.request, |
| 102 | + snapshot_configuration=pathfinder_fixture.snapshot_configuration, |
| 103 | + obj=find_descendants_result, |
| 104 | + ) |
0 commit comments