|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | +from collections.abc import Iterable, Sequence |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted |
| 8 | +from dbt_semantic_interfaces.protocols import Metric |
| 9 | +from dbt_semantic_interfaces.type_enums import MetricType |
| 10 | +from metricflow_semantics.semantic_graph.model_id import SemanticModelId |
| 11 | +from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec |
| 12 | +from metricflow_semantics.specs.metric_spec import MetricSpec |
| 13 | +from metricflow_semantics.specs.where_filter.where_filter_spec_factory import WhereFilterSpecFactory |
| 14 | +from metricflow_semantics.toolkit.mf_logging.lazy_formattable import LazyFormat |
| 15 | +from typing_extensions import override |
| 16 | + |
| 17 | +from metricflow.metric_evaluation.metric_query_planner import MetricEvaluationPlanner |
| 18 | +from metricflow.metric_evaluation.plan.me_edges import MetricQueryDependencyEdge |
| 19 | +from metricflow.metric_evaluation.plan.me_nodes import ( |
| 20 | + ConversionMetricQueryNode, |
| 21 | + CumulativeMetricQueryNode, |
| 22 | + DerivedMetricsQueryNode, |
| 23 | + MetricQueryNode, |
| 24 | + SimpleMetricsQueryNode, |
| 25 | + TopLevelQueryNode, |
| 26 | +) |
| 27 | +from metricflow.metric_evaluation.plan.me_plan import ( |
| 28 | + MetricEvaluationPlan, |
| 29 | + MutableMetricEvaluationPlan, |
| 30 | +) |
| 31 | +from metricflow.metric_evaluation.plan.query_element import MetricQueryElement, MetricQueryPropertySet |
| 32 | +from metricflow.plan_conversion.node_processor import PredicatePushdownState |
| 33 | + |
| 34 | +logger = logging.getLogger(__name__) |
| 35 | + |
| 36 | + |
| 37 | +class DepthFirstSearchMetricEvaluationPlanner(MetricEvaluationPlanner): |
| 38 | + """Builds a metric evaluation plan using a depth-first traversal of the metric dependency graph. |
| 39 | +
|
| 40 | + For example, the metric evaluation plan for the query [`bookings_per_listing`, `bookings`] results in a plan |
| 41 | + that has the following edges: |
| 42 | +
|
| 43 | + MetricQuery([`bookings_per_listing`]) -> MetricQuery([`bookings`]) |
| 44 | + MetricQuery([`bookings_per_listing`]) -> MetricQuery([`listings`]) |
| 45 | + Top Level Query -> MetricQuery([`bookings_per_listing`]) |
| 46 | + Top Level Query -> MetricQuery([`bookings`]) |
| 47 | +
|
| 48 | + This mirrors the original approach to compute metrics in the `DataflowPlanBuilder`. |
| 49 | + """ |
| 50 | + |
| 51 | + @override |
| 52 | + def build_plan( |
| 53 | + self, |
| 54 | + metric_specs: Sequence[MetricSpec], |
| 55 | + group_by_item_specs: Sequence[LinkableInstanceSpec], |
| 56 | + predicate_pushdown_state: PredicatePushdownState, |
| 57 | + filter_spec_factory: WhereFilterSpecFactory, |
| 58 | + ) -> MetricEvaluationPlan: |
| 59 | + """Build a metric evaluation plan using iterative depth-first traversal. |
| 60 | +
|
| 61 | + This resolves each requested metric into metric-query nodes and dependency edges, then attaches a top-level |
| 62 | + query node that references all requested metrics. |
| 63 | + """ |
| 64 | + top_level_query_elements = tuple( |
| 65 | + MetricQueryElement.create( |
| 66 | + metric_spec=metric_spec, |
| 67 | + group_by_item_specs=group_by_item_specs, |
| 68 | + predicate_pushdown_state=predicate_pushdown_state, |
| 69 | + ) |
| 70 | + for metric_spec in metric_specs |
| 71 | + ) |
| 72 | + |
| 73 | + evaluation_plan = MutableMetricEvaluationPlan.create() |
| 74 | + |
| 75 | + # The query elements to process in the iterative DFS traversal loop. The next element is popped from the right |
| 76 | + # so elements are added in reverse to preserve order. |
| 77 | + query_elements_to_process: list[MetricQueryElement] = list(reversed(top_level_query_elements)) |
| 78 | + # Keeps track of the query elements that have been processed into a node in the evaluation plan. |
| 79 | + query_element_to_node: dict[MetricQueryElement, MetricQueryNode] = {} |
| 80 | + |
| 81 | + while query_elements_to_process: |
| 82 | + current_query_element = query_elements_to_process.pop() |
| 83 | + logger.debug(LazyFormat("Handling query element", current_query_element=current_query_element)) |
| 84 | + if current_query_element in query_element_to_node: |
| 85 | + continue |
| 86 | + |
| 87 | + current_metric_spec = current_query_element.metric_spec |
| 88 | + current_query_properties = current_query_element.query_properties |
| 89 | + current_predicate_pushdown_state = current_query_element.predicate_pushdown_state |
| 90 | + |
| 91 | + metric_name = current_metric_spec.element_name |
| 92 | + metric = self._manifest_object_lookup.get_metric(metric_name) |
| 93 | + metric_type = metric.type |
| 94 | + |
| 95 | + # Handle non-derived metrics. |
| 96 | + metric_query_node = self._create_base_metric_query_node( |
| 97 | + metric=metric, |
| 98 | + metric_type=metric_type, |
| 99 | + metric_spec=current_metric_spec, |
| 100 | + query_properties=current_query_properties, |
| 101 | + ) |
| 102 | + if metric_query_node is not None: |
| 103 | + evaluation_plan.add_node(metric_query_node) |
| 104 | + query_element_to_node[current_query_element] = metric_query_node |
| 105 | + continue |
| 106 | + |
| 107 | + # Handle derived metrics. |
| 108 | + input_query_elements = self._get_input_metric_query_elements_for_derived_metric( |
| 109 | + metric_spec=current_metric_spec, |
| 110 | + group_by_item_specs=current_query_element.group_by_item_specs, |
| 111 | + predicate_pushdown_state=current_predicate_pushdown_state, |
| 112 | + filter_spec_factory=filter_spec_factory, |
| 113 | + ) |
| 114 | + assert len(input_query_elements) > 0, LazyFormat( |
| 115 | + "Expected a ratio or derived metric to have input query elements", |
| 116 | + current_metric_spec=current_metric_spec, |
| 117 | + metric=metric, |
| 118 | + ) |
| 119 | + |
| 120 | + inputs_that_need_processing = tuple( |
| 121 | + input_query_element |
| 122 | + for input_query_element in input_query_elements |
| 123 | + if input_query_element not in query_element_to_node |
| 124 | + ) |
| 125 | + # To implement DFS traversal, check if the input nodes have been processed. If not, add the input nodes |
| 126 | + # for processing and then try to process the current node again. |
| 127 | + if len(inputs_that_need_processing) > 0: |
| 128 | + # Add the current node first as the loop pops the next current element from the end. |
| 129 | + query_elements_to_process.append(current_query_element) |
| 130 | + # Adding inputs in reverse order to match traversal order with definition order. |
| 131 | + query_elements_to_process.extend(reversed(inputs_that_need_processing)) |
| 132 | + continue |
| 133 | + |
| 134 | + # All inputs of the derived metric have been processed, so add the node for the derived metric and the |
| 135 | + # edges. |
| 136 | + derived_metric_query_node = DerivedMetricsQueryNode.create( |
| 137 | + computed_metric_specs=[current_metric_spec], |
| 138 | + passthrough_metric_specs=(), |
| 139 | + query_properties=current_query_properties, |
| 140 | + ) |
| 141 | + evaluation_plan.add_node(derived_metric_query_node) |
| 142 | + |
| 143 | + for input_query_element in input_query_elements: |
| 144 | + input_query_node = query_element_to_node[input_query_element] |
| 145 | + evaluation_plan.add_edge( |
| 146 | + MetricQueryDependencyEdge.create( |
| 147 | + target_node=derived_metric_query_node, |
| 148 | + target_node_output_spec=current_metric_spec, |
| 149 | + source_node=input_query_node, |
| 150 | + source_node_output_spec=input_query_element.metric_spec, |
| 151 | + ) |
| 152 | + ) |
| 153 | + |
| 154 | + query_element_to_node[current_query_element] = derived_metric_query_node |
| 155 | + |
| 156 | + # Once nodes for all metrics in the query have been generated, add a `TopLevelQueryNode` to provide a single |
| 157 | + # entry point. |
| 158 | + top_level_query_node = TopLevelQueryNode.create( |
| 159 | + passthrough_metric_specs=metric_specs, |
| 160 | + query_properties=MetricQueryPropertySet.create(group_by_item_specs, predicate_pushdown_state), |
| 161 | + ) |
| 162 | + evaluation_plan.add_node(top_level_query_node) |
| 163 | + |
| 164 | + for top_level_query_element in top_level_query_elements: |
| 165 | + evaluation_plan.add_edge( |
| 166 | + MetricQueryDependencyEdge.create( |
| 167 | + target_node=top_level_query_node, |
| 168 | + target_node_output_spec=top_level_query_element.metric_spec, |
| 169 | + source_node=query_element_to_node[top_level_query_element], |
| 170 | + source_node_output_spec=top_level_query_element.metric_spec, |
| 171 | + ) |
| 172 | + ) |
| 173 | + |
| 174 | + return evaluation_plan |
| 175 | + |
| 176 | + def _create_base_metric_query_node( |
| 177 | + self, |
| 178 | + metric: Metric, |
| 179 | + metric_type: MetricType, |
| 180 | + metric_spec: MetricSpec, |
| 181 | + query_properties: MetricQueryPropertySet, |
| 182 | + ) -> Optional[MetricQueryNode]: |
| 183 | + """Return a node for base metric types or `None` for metrics that require dependency expansion.""" |
| 184 | + if metric_type is MetricType.SIMPLE: |
| 185 | + metric_aggregation_params = metric.type_params.metric_aggregation_params |
| 186 | + if metric_aggregation_params is None: |
| 187 | + raise ValueError( |
| 188 | + LazyFormat( |
| 189 | + "Simple metric is missing metric aggregation parameters", |
| 190 | + metric_spec=metric_spec, |
| 191 | + metric=metric, |
| 192 | + ) |
| 193 | + ) |
| 194 | + return SimpleMetricsQueryNode.create( |
| 195 | + model_id=SemanticModelId.get_instance(metric_aggregation_params.semantic_model), |
| 196 | + metric_specs=(metric_spec,), |
| 197 | + query_properties=query_properties, |
| 198 | + ) |
| 199 | + |
| 200 | + if metric_type is MetricType.CUMULATIVE: |
| 201 | + return CumulativeMetricQueryNode.create(metric_spec=metric_spec, query_properties=query_properties) |
| 202 | + elif metric_type is MetricType.CONVERSION: |
| 203 | + return ConversionMetricQueryNode.create(metric_spec=metric_spec, query_properties=query_properties) |
| 204 | + elif metric_type is MetricType.RATIO or metric_type is MetricType.DERIVED: |
| 205 | + return None |
| 206 | + else: |
| 207 | + assert_values_exhausted(metric_type) |
| 208 | + |
| 209 | + def _get_input_metric_query_elements_for_derived_metric( |
| 210 | + self, |
| 211 | + metric_spec: MetricSpec, |
| 212 | + group_by_item_specs: Iterable[LinkableInstanceSpec], |
| 213 | + predicate_pushdown_state: PredicatePushdownState, |
| 214 | + filter_spec_factory: WhereFilterSpecFactory, |
| 215 | + ) -> Sequence[MetricQueryElement]: |
| 216 | + """Return input query elements for a ratio / derived metric. |
| 217 | +
|
| 218 | + Input query elements generally inherit group-by and predicate settings from the metric being expanded. |
| 219 | + Time-offset metrics are handled differently - see appropriate section in the `DataflowPlanBuilder`. |
| 220 | + """ |
| 221 | + additional_filter_specs = metric_spec.where_filter_specs |
| 222 | + group_by_item_specs_for_inputs = group_by_item_specs |
| 223 | + predicate_pushdown_state_for_inputs = predicate_pushdown_state |
| 224 | + |
| 225 | + if metric_spec.has_time_offset: |
| 226 | + group_by_item_specs_for_inputs = self._required_group_by_items_for_inputs_to_a_time_offset_metric( |
| 227 | + queried_group_by_specs=group_by_item_specs, |
| 228 | + filter_specs=metric_spec.where_filter_specs, |
| 229 | + ) |
| 230 | + predicate_pushdown_state_for_inputs = PredicatePushdownState.with_pushdown_disabled() |
| 231 | + # If metric is offset, we'll apply where constraint after offset to avoid removing values |
| 232 | + # unexpectedly. Time constraint will be applied by INNER JOINing to time spine. |
| 233 | + # We may consider encapsulating this in pushdown state later, but as of this moment pushdown |
| 234 | + # is about post-join to pre-join for dimension access, and relies on the builder to collect |
| 235 | + # predicates from query and metric specs and make them available at simple-metric-input level. |
| 236 | + additional_filter_specs = () |
| 237 | + |
| 238 | + input_metric_specs = self._build_input_metric_specs_for_derived_metric( |
| 239 | + metric_name=metric_spec.element_name, |
| 240 | + filter_spec_factory=filter_spec_factory, |
| 241 | + additional_filter_specs=additional_filter_specs, |
| 242 | + ) |
| 243 | + return tuple( |
| 244 | + MetricQueryElement.create( |
| 245 | + metric_spec=input_metric_spec, |
| 246 | + group_by_item_specs=group_by_item_specs_for_inputs, |
| 247 | + predicate_pushdown_state=predicate_pushdown_state_for_inputs, |
| 248 | + ) |
| 249 | + for input_metric_spec in input_metric_specs |
| 250 | + ) |
0 commit comments