|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | from collections import defaultdict
|
| 4 | +from dataclasses import dataclass |
4 | 5 | from typing import Dict, FrozenSet, Mapping, Sequence, Set
|
5 | 6 |
|
| 7 | +from metricflow_semantics.collection_helpers.merger import Mergeable |
6 | 8 | from typing_extensions import override
|
7 | 9 |
|
8 | 10 | from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
|
9 | 11 | from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler
|
| 12 | +from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode |
10 | 13 |
|
11 | 14 |
|
12 | 15 | class DataflowPlanAnalyzer:
|
@@ -36,6 +39,12 @@ def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNo
|
36 | 39 |
|
37 | 40 | return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor)))
|
38 | 41 |
|
@staticmethod
def group_nodes_by_type(dataflow_plan: DataflowPlan) -> DataflowPlanNodeSet:
    """Group the nodes of a dataflow plan by node type.

    The plan's sink node is asked to accept a `_GroupNodesByTypeVisitor`, and
    whatever that visit returns is handed back to the caller unchanged.

    Args:
        dataflow_plan: The plan whose nodes should be grouped.

    Returns:
        A `DataflowPlanNodeSet` built by the visitor for the given plan.
    """
    return dataflow_plan.sink_node.accept(_GroupNodesByTypeVisitor())
| 47 | + |
39 | 48 |
|
40 | 49 | class _CountDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
|
41 | 50 | """Helper visitor to build a dict from a node in the plan to the number of times it appears in the plan."""
|
@@ -77,3 +86,41 @@ def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode
|
77 | 86 | common_branch_leaf_nodes.update(parent_node.accept(self))
|
78 | 87 |
|
79 | 88 | return frozenset(common_branch_leaf_nodes)
|
| 89 | + |
| 90 | + |
@dataclass(frozen=True)
class DataflowPlanNodeSet(Mergeable):
    """Immutable collection of dataflow plan nodes, bucketed by node type.

    Only `ComputeMetricsNode` has a bucket today because it is the only type
    needed by current use cases; buckets for other node types can be added
    as fields later.
    """

    # All `ComputeMetricsNode` instances collected so far.
    compute_metric_nodes: FrozenSet[ComputeMetricsNode]

    def merge(self, other: DataflowPlanNodeSet) -> DataflowPlanNodeSet:
        """Return a new set containing the nodes of both `self` and `other`.

        Neither operand is modified; the result holds the union of the two
        `compute_metric_nodes` sets.
        """
        combined_compute_metric_nodes = self.compute_metric_nodes.union(other.compute_metric_nodes)
        return DataflowPlanNodeSet(compute_metric_nodes=combined_compute_metric_nodes)

    @classmethod
    def empty_instance(cls) -> DataflowPlanNodeSet:
        """Return a node set that contains no nodes."""
        return DataflowPlanNodeSet(compute_metric_nodes=frozenset())
| 111 | + |
| 112 | + |
class _GroupNodesByTypeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[DataflowPlanNodeSet]):
    """Visitor that collects the nodes reachable from the visited node, grouped by type.

    Each visit returns a `DataflowPlanNodeSet` covering the visited node's
    parents (recursively). Only `ComputeMetricsNode` instances are recorded;
    other node types fall through to the default handler.
    """

    @override
    def _default_handler(self, node: DataflowPlanNode) -> DataflowPlanNodeSet:
        # Visit every parent first, then fold the per-parent results into one set.
        # NOTE(review): for a node without parents this passes an empty list to
        # `merge_iterable` - presumably that yields an empty instance; confirm in `Mergeable`.
        parent_results = [parent_node.accept(self) for parent_node in node.parent_nodes]
        return DataflowPlanNodeSet.merge_iterable(parent_results)

    @override
    def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> DataflowPlanNodeSet:
        # Collect everything found via the parents, then add this node itself.
        nodes_from_parents = self._default_handler(node)
        this_node_only = DataflowPlanNodeSet(compute_metric_nodes=frozenset({node}))
        return nodes_from_parents.merge(this_node_only)
0 commit comments