Skip to content

Commit dadb857

Browse files
committed
/* PR_START p--cte 19 */ Add method to group nodes by type.
1 parent e1901ab commit dadb857

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

metricflow/dataflow/dataflow_plan_analyzer.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
from dataclasses import dataclass
45
from typing import Dict, FrozenSet, Mapping, Sequence, Set
56

7+
from metricflow_semantics.collection_helpers.merger import Mergeable
68
from typing_extensions import override
79

810
from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
911
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler
12+
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
1013

1114

1215
class DataflowPlanAnalyzer:
@@ -36,6 +39,12 @@ def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNo
3639

3740
return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor)))
3841

42+
@staticmethod
def group_nodes_by_type(dataflow_plan: DataflowPlan) -> DataflowPlanNodeSet:
    """Group the nodes of a dataflow plan by type.

    The plan's sink node is handed a `_GroupNodesByTypeVisitor`, and whatever that visitor returns is passed
    back to the caller.
    """
    visitor = _GroupNodesByTypeVisitor()
    sink_node = dataflow_plan.sink_node
    return sink_node.accept(visitor)
47+
3948

4049
class _CountDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
4150
"""Helper visitor to build a dict from a node in the plan to the number of times it appears in the plan."""
@@ -77,3 +86,41 @@ def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode
7786
common_branch_leaf_nodes.update(parent_node.accept(self))
7887

7988
return frozenset(common_branch_leaf_nodes)
89+
90+
91+
@dataclass(frozen=True)
class DataflowPlanNodeSet(Mergeable):
    """A collection of dataflow plan nodes, with one field per node type.

    Only `ComputeMetricsNode` is tracked because it is the only type needed by current use cases. Fields for
    other node types can be added later.
    """

    # The `ComputeMetricsNode` instances in this set.
    compute_metric_nodes: FrozenSet[ComputeMetricsNode]

    def merge(self, other: DataflowPlanNodeSet) -> DataflowPlanNodeSet:
        """Return a new set whose field is the union of this set's field and `other`'s."""
        combined_compute_metric_nodes = self.compute_metric_nodes.union(other.compute_metric_nodes)
        return DataflowPlanNodeSet(compute_metric_nodes=combined_compute_metric_nodes)

    @classmethod
    def empty_instance(cls) -> DataflowPlanNodeSet:
        """Return a set that contains no nodes."""
        no_nodes: FrozenSet[ComputeMetricsNode] = frozenset()
        return DataflowPlanNodeSet(compute_metric_nodes=no_nodes)
111+
112+
113+
class _GroupNodesByTypeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[DataflowPlanNodeSet]):
    """Visitor that collects dataflow plan nodes into a `DataflowPlanNodeSet`, grouped by type."""

    @override
    def _default_handler(self, node: DataflowPlanNode) -> DataflowPlanNodeSet:
        """Merge the node sets produced by visiting each parent of `node`.

        The given node itself is not added here; type-specific `visit_*` overrides do that.
        """
        parent_node_sets = [parent_node.accept(self) for parent_node in node.parent_nodes]
        return DataflowPlanNodeSet.merge_iterable(parent_node_sets)

    @override
    def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> DataflowPlanNodeSet:
        """Return the node set gathered from the parents of `node`, with `node` itself added."""
        # Visit the parents first, then fold in a single-node set for this node.
        nodes_from_parents = self._default_handler(node)
        this_node_only = DataflowPlanNodeSet(compute_metric_nodes=frozenset({node}))
        return nodes_from_parents.merge(this_node_only)

0 commit comments

Comments
 (0)