Skip to content

Commit 0b1cd6a

Browse files
authored
Add DFS metric evaluation planner (#1985)
This PR adds a metric evaluation planner that creates the plan using a depth-first (DFS) traversal of each metric's definition and its input metrics. This mirrors the current approach used in the dataflow plan builder.
1 parent 3c19f87 commit 0b1cd6a

19 files changed

+1082
-19
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
kind: Under the Hood
2+
body: Add DFS metric evaluation planner
3+
time: 2026-03-10T08:51:19.497736-07:00
4+
custom:
5+
Author: plypaul
6+
Issue: "1985"

metricflow-semantics/metricflow_semantics/toolkit/string_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import textwrap
4+
from typing import Optional
45

56
MF_INDENT_2_SPACE = " "
67

@@ -36,3 +37,8 @@ def mf_dedent(text: str) -> str:
3637
)
3738
"""
3839
return textwrap.dedent(text.lstrip("\n")).rstrip("\n")
40+
41+
42+
def mf_wrap(text: str, width: Optional[int] = None) -> str:
    """Wrap `text` so each output line is at most `width` characters (default 80).

    Intended for user-facing messages. Whitespace is normalized by `textwrap.wrap`,
    so existing line breaks in `text` are not preserved.
    """
    effective_width = 80 if width is None else width
    wrapped_lines = textwrap.wrap(text=text, width=effective_width)
    return "\n".join(wrapped_lines)

metricflow/dataflow/nodes/compute_metrics.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Iterable
44
from dataclasses import dataclass
5+
from functools import cached_property
56
from typing import Sequence, Set, Tuple
67

78
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
@@ -67,11 +68,16 @@ def description(self) -> str: # noqa: D102
6768

6869
@property
def displayed_properties(self) -> Sequence[DisplayedProperty]:
    """Properties to display for this node.

    Extends the parent class's displayed properties with one `metric_spec` entry per computed metric
    spec and per passthrough metric spec, plus the output group-by metric instances when present.
    """
    displayed_properties = list(super().displayed_properties)
    # TODO: Use different key names for computed / passthrough metric specs.
    displayed_properties.extend(
        DisplayedProperty("metric_spec", metric_spec) for metric_spec in self.computed_metric_specs
    )
    displayed_properties.extend(
        DisplayedProperty("metric_spec", metric_spec) for metric_spec in self.passthrough_metric_specs
    )
    if self.output_group_by_metric_instances:
        # NOTE(review): all instances are grouped under a single displayed property (not one per
        # instance) — presumably intentional; confirm against how renderers consume this.
        displayed_properties.append(
            DisplayedProperty("output_group_by_metric_instances", self.output_group_by_metric_instances),
        )
    return displayed_properties
@@ -84,7 +90,10 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa:
8490
if not isinstance(other_node, self.__class__):
8591
return False
8692

87-
if other_node.computed_metric_specs != self.computed_metric_specs:
93+
if (
94+
other_node.computed_metric_specs != self.computed_metric_specs
95+
or other_node.passthrough_metric_specs != self.passthrough_metric_specs
96+
):
8897
return False
8998

9099
return (
@@ -105,8 +114,8 @@ def can_combine(self, other_node: ComputeMetricsNode) -> Tuple[bool, str]:
105114

106115
if other_node.output_group_by_metric_instances != self.output_group_by_metric_instances:
107116
return False, "one node is a group by metric source node"
108-
109-
alias_to_metric_spec = {spec.alias: spec for spec in self.computed_metric_specs if spec.alias is not None}
117+
metric_specs = self.computed_metric_specs + self.passthrough_metric_specs
118+
alias_to_metric_spec = {spec.alias: spec for spec in metric_specs if spec.alias is not None}
110119

111120
for spec in other_node.computed_metric_specs:
112121
if (
@@ -135,3 +144,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> Comp
135144
@override
136145
def aggregated_to_elements(self) -> Set[LinkableInstanceSpec]:
137146
return set(self._aggregated_to_elements)
147+
148+
@cached_property
def metric_specs(self) -> Sequence[MetricSpec]:
    """All metric specs handled by this node: the computed specs followed by the passthrough specs.

    Cached since the node is immutable (frozen dataclass fields) — assumes `computed_metric_specs`
    and `passthrough_metric_specs` support `+` concatenation (e.g. both tuples); TODO confirm.
    """
    return self.computed_metric_specs + self.passthrough_metric_specs

metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

3+
import itertools
34
import logging
45
from dataclasses import dataclass
56
from typing import List, Optional, Sequence
67

7-
from metricflow_semantics.specs.metric_spec import MetricSpec
8+
from metricflow_semantics.toolkit.collections.ordered_set import FrozenOrderedSet
89
from metricflow_semantics.toolkit.mf_logging.lazy_formattable import LazyFormat
910

1011
from metricflow.dataflow.dataflow_plan import (
@@ -338,19 +339,18 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetrics
338339
combined_parent_node = combined_parent_nodes[0]
339340
assert combined_parent_node is not None
340341

341-
# Dedupe (preserving order for output consistency) as it's possible for multiple derived metrics to use the same
342-
# metric.
343-
unique_metric_specs: List[MetricSpec] = []
344-
for metric_spec in tuple(self._current_left_node.computed_metric_specs) + tuple(
345-
current_right_node.computed_metric_specs
346-
):
347-
if metric_spec not in unique_metric_specs:
348-
unique_metric_specs.append(metric_spec)
349-
350342
combined_node = ComputeMetricsNode.create(
351343
parent_node=combined_parent_node,
352-
computed_metric_specs=unique_metric_specs,
353-
passthrough_metric_specs=(),
344+
# Dedupe (preserving order for output consistency) as it's possible for multiple derived metrics to use the same
345+
# metric.
346+
computed_metric_specs=FrozenOrderedSet(
347+
itertools.chain(self._current_left_node.computed_metric_specs, current_right_node.computed_metric_specs)
348+
),
349+
passthrough_metric_specs=FrozenOrderedSet(
350+
itertools.chain(
351+
self._current_left_node.passthrough_metric_specs, current_right_node.passthrough_metric_specs
352+
)
353+
),
354354
aggregated_to_elements=current_right_node.aggregated_to_elements,
355355
output_group_by_metric_instances=current_right_node.output_group_by_metric_instances,
356356
)

metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranch
172172
optimized_branch=ComputeMetricsNode.create(
173173
parent_node=optimized_parent_result.optimized_branch,
174174
computed_metric_specs=node.computed_metric_specs,
175-
passthrough_metric_specs=(),
175+
passthrough_metric_specs=node.passthrough_metric_specs,
176176
output_group_by_metric_instances=node.output_group_by_metric_instances,
177177
aggregated_to_elements=node.aggregated_to_elements,
178178
)
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from collections.abc import Iterable, Sequence
5+
from typing import Optional
6+
7+
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
8+
from dbt_semantic_interfaces.protocols import Metric
9+
from dbt_semantic_interfaces.type_enums import MetricType
10+
from metricflow_semantics.semantic_graph.model_id import SemanticModelId
11+
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
12+
from metricflow_semantics.specs.metric_spec import MetricSpec
13+
from metricflow_semantics.specs.where_filter.where_filter_spec_factory import WhereFilterSpecFactory
14+
from metricflow_semantics.toolkit.mf_logging.lazy_formattable import LazyFormat
15+
from typing_extensions import override
16+
17+
from metricflow.metric_evaluation.metric_query_planner import MetricEvaluationPlanner
18+
from metricflow.metric_evaluation.plan.me_edges import MetricQueryDependencyEdge
19+
from metricflow.metric_evaluation.plan.me_nodes import (
20+
ConversionMetricQueryNode,
21+
CumulativeMetricQueryNode,
22+
DerivedMetricsQueryNode,
23+
MetricQueryNode,
24+
SimpleMetricsQueryNode,
25+
TopLevelQueryNode,
26+
)
27+
from metricflow.metric_evaluation.plan.me_plan import (
28+
MetricEvaluationPlan,
29+
MutableMetricEvaluationPlan,
30+
)
31+
from metricflow.metric_evaluation.plan.query_element import MetricQueryElement, MetricQueryPropertySet
32+
from metricflow.plan_conversion.node_processor import PredicatePushdownState
33+
34+
logger = logging.getLogger(__name__)
35+
36+
37+
class DepthFirstSearchMetricEvaluationPlanner(MetricEvaluationPlanner):
    """Builds a metric evaluation plan using a depth-first traversal of the metric dependency graph.

    For example, the metric evaluation plan for the query [`bookings_per_listing`, `bookings`] results in a plan
    that has the following edges:

        MetricQuery([`bookings_per_listing`]) -> MetricQuery([`bookings`])
        MetricQuery([`bookings_per_listing`]) -> MetricQuery([`listings`])
        Top Level Query -> MetricQuery([`bookings_per_listing`])
        Top Level Query -> MetricQuery([`bookings`])

    This mirrors the original approach to compute metrics in the `DataflowPlanBuilder`.
    """

    @override
    def build_plan(
        self,
        metric_specs: Sequence[MetricSpec],
        group_by_item_specs: Sequence[LinkableInstanceSpec],
        predicate_pushdown_state: PredicatePushdownState,
        filter_spec_factory: WhereFilterSpecFactory,
    ) -> MetricEvaluationPlan:
        """Build a metric evaluation plan using iterative depth-first traversal.

        This resolves each requested metric into metric-query nodes and dependency edges, then attaches a top-level
        query node that references all requested metrics.

        Args:
            metric_specs: The metric specs requested in the query.
            group_by_item_specs: The group-by items requested in the query.
            predicate_pushdown_state: The predicate-pushdown state associated with the query.
            filter_spec_factory: Factory used to build filter specs for the inputs of derived metrics.

        Returns:
            The metric evaluation plan for the requested metrics.

        NOTE(review): assumes the metric dependency graph is acyclic — a cycle in metric definitions
        would make the traversal loop forever; confirm cycles are rejected upstream (e.g. at manifest
        validation).
        """
        # One query element per requested metric; these become the direct children of the top-level node.
        top_level_query_elements = tuple(
            MetricQueryElement.create(
                metric_spec=metric_spec,
                group_by_item_specs=group_by_item_specs,
                predicate_pushdown_state=predicate_pushdown_state,
            )
            for metric_spec in metric_specs
        )

        evaluation_plan = MutableMetricEvaluationPlan.create()

        # The query elements to process in the iterative DFS traversal loop. The next element is popped from the right
        # so elements are added in reverse to preserve order.
        query_elements_to_process: list[MetricQueryElement] = list(reversed(top_level_query_elements))
        # Keeps track of the query elements that have been processed into a node in the evaluation plan.
        # Also serves as the "visited" set that deduplicates shared inputs across derived metrics.
        query_element_to_node: dict[MetricQueryElement, MetricQueryNode] = {}

        while query_elements_to_process:
            current_query_element = query_elements_to_process.pop()
            logger.debug(LazyFormat("Handling query element", current_query_element=current_query_element))
            # Already resolved to a node (e.g. the same metric is an input to multiple derived metrics).
            if current_query_element in query_element_to_node:
                continue

            current_metric_spec = current_query_element.metric_spec
            current_query_properties = current_query_element.query_properties
            current_predicate_pushdown_state = current_query_element.predicate_pushdown_state

            metric_name = current_metric_spec.element_name
            metric = self._manifest_object_lookup.get_metric(metric_name)
            metric_type = metric.type

            # Handle non-derived metrics. `_create_base_metric_query_node` returns None for ratio /
            # derived metrics, which require expansion into their input metrics below.
            metric_query_node = self._create_base_metric_query_node(
                metric=metric,
                metric_type=metric_type,
                metric_spec=current_metric_spec,
                query_properties=current_query_properties,
            )
            if metric_query_node is not None:
                evaluation_plan.add_node(metric_query_node)
                query_element_to_node[current_query_element] = metric_query_node
                continue

            # Handle derived metrics.
            input_query_elements = self._get_input_metric_query_elements_for_derived_metric(
                metric_spec=current_metric_spec,
                group_by_item_specs=current_query_element.group_by_item_specs,
                predicate_pushdown_state=current_predicate_pushdown_state,
                filter_spec_factory=filter_spec_factory,
            )
            assert len(input_query_elements) > 0, LazyFormat(
                "Expected a ratio or derived metric to have input query elements",
                current_metric_spec=current_metric_spec,
                metric=metric,
            )

            inputs_that_need_processing = tuple(
                input_query_element
                for input_query_element in input_query_elements
                if input_query_element not in query_element_to_node
            )
            # To implement DFS traversal, check if the input nodes have been processed. If not, add the input nodes
            # for processing and then try to process the current node again.
            if len(inputs_that_need_processing) > 0:
                # Add the current node first as the loop pops the next current element from the end.
                query_elements_to_process.append(current_query_element)
                # Adding inputs in reverse order to match traversal order with definition order.
                query_elements_to_process.extend(reversed(inputs_that_need_processing))
                continue

            # All inputs of the derived metric have been processed, so add the node for the derived metric and the
            # edges.
            derived_metric_query_node = DerivedMetricsQueryNode.create(
                # NOTE(review): a list is passed here while tuples are used elsewhere in this file —
                # presumably `create` normalizes the sequence type; confirm.
                computed_metric_specs=[current_metric_spec],
                passthrough_metric_specs=(),
                query_properties=current_query_properties,
            )
            evaluation_plan.add_node(derived_metric_query_node)

            # One edge per input: the derived metric's node depends on each input metric's node.
            for input_query_element in input_query_elements:
                input_query_node = query_element_to_node[input_query_element]
                evaluation_plan.add_edge(
                    MetricQueryDependencyEdge.create(
                        target_node=derived_metric_query_node,
                        target_node_output_spec=current_metric_spec,
                        source_node=input_query_node,
                        source_node_output_spec=input_query_element.metric_spec,
                    )
                )

            query_element_to_node[current_query_element] = derived_metric_query_node

        # Once nodes for all metrics in the query have been generated, add a `TopLevelQueryNode` to provide a single
        # entry point.
        top_level_query_node = TopLevelQueryNode.create(
            passthrough_metric_specs=metric_specs,
            query_properties=MetricQueryPropertySet.create(group_by_item_specs, predicate_pushdown_state),
        )
        evaluation_plan.add_node(top_level_query_node)

        for top_level_query_element in top_level_query_elements:
            evaluation_plan.add_edge(
                MetricQueryDependencyEdge.create(
                    target_node=top_level_query_node,
                    target_node_output_spec=top_level_query_element.metric_spec,
                    source_node=query_element_to_node[top_level_query_element],
                    source_node_output_spec=top_level_query_element.metric_spec,
                )
            )

        # NOTE(review): the declared return type is `MetricEvaluationPlan` but a
        # `MutableMetricEvaluationPlan` is returned directly — presumably the mutable plan is a
        # subtype (or callers treat it as read-only from here); confirm.
        return evaluation_plan

    def _create_base_metric_query_node(
        self,
        metric: Metric,
        metric_type: MetricType,
        metric_spec: MetricSpec,
        query_properties: MetricQueryPropertySet,
    ) -> Optional[MetricQueryNode]:
        """Return a node for base metric types or `None` for metrics that require dependency expansion.

        Args:
            metric: The metric definition from the manifest.
            metric_type: The type of `metric` (passed separately so the caller resolves it once).
            metric_spec: The spec identifying the metric in the query.
            query_properties: The query properties to attach to the created node.

        Returns:
            A node for SIMPLE / CUMULATIVE / CONVERSION metrics; `None` for RATIO / DERIVED metrics,
            which are expanded into their input metrics by the caller.

        Raises:
            ValueError: If a simple metric is missing its aggregation parameters.
        """
        if metric_type is MetricType.SIMPLE:
            metric_aggregation_params = metric.type_params.metric_aggregation_params
            if metric_aggregation_params is None:
                raise ValueError(
                    LazyFormat(
                        "Simple metric is missing metric aggregation parameters",
                        metric_spec=metric_spec,
                        metric=metric,
                    )
                )
            return SimpleMetricsQueryNode.create(
                model_id=SemanticModelId.get_instance(metric_aggregation_params.semantic_model),
                metric_specs=(metric_spec,),
                query_properties=query_properties,
            )

        if metric_type is MetricType.CUMULATIVE:
            return CumulativeMetricQueryNode.create(metric_spec=metric_spec, query_properties=query_properties)
        elif metric_type is MetricType.CONVERSION:
            return ConversionMetricQueryNode.create(metric_spec=metric_spec, query_properties=query_properties)
        elif metric_type is MetricType.RATIO or metric_type is MetricType.DERIVED:
            # Ratio / derived metrics have no base node — the caller expands them into input metrics.
            return None
        else:
            # Static check: fails to type-check if a new MetricType is added without handling it here.
            assert_values_exhausted(metric_type)

    def _get_input_metric_query_elements_for_derived_metric(
        self,
        metric_spec: MetricSpec,
        group_by_item_specs: Iterable[LinkableInstanceSpec],
        predicate_pushdown_state: PredicatePushdownState,
        filter_spec_factory: WhereFilterSpecFactory,
    ) -> Sequence[MetricQueryElement]:
        """Return input query elements for a ratio / derived metric.

        Input query elements generally inherit group-by and predicate settings from the metric being expanded.
        Time-offset metrics are handled differently - see appropriate section in the `DataflowPlanBuilder`.

        Args:
            metric_spec: The ratio / derived metric spec to expand.
            group_by_item_specs: The group-by items requested for `metric_spec`.
            predicate_pushdown_state: The predicate-pushdown state for `metric_spec`.
            filter_spec_factory: Factory used to build filter specs for the input metrics.

        Returns:
            One query element per input metric of the derived metric.
        """
        additional_filter_specs = metric_spec.where_filter_specs
        group_by_item_specs_for_inputs = group_by_item_specs
        predicate_pushdown_state_for_inputs = predicate_pushdown_state

        if metric_spec.has_time_offset:
            group_by_item_specs_for_inputs = self._required_group_by_items_for_inputs_to_a_time_offset_metric(
                queried_group_by_specs=group_by_item_specs,
                filter_specs=metric_spec.where_filter_specs,
            )
            predicate_pushdown_state_for_inputs = PredicatePushdownState.with_pushdown_disabled()
            # If metric is offset, we'll apply where constraint after offset to avoid removing values
            # unexpectedly. Time constraint will be applied by INNER JOINing to time spine.
            # We may consider encapsulating this in pushdown state later, but as of this moment pushdown
            # is about post-join to pre-join for dimension access, and relies on the builder to collect
            # predicates from query and metric specs and make them available at simple-metric-input level.
            additional_filter_specs = ()

        input_metric_specs = self._build_input_metric_specs_for_derived_metric(
            metric_name=metric_spec.element_name,
            filter_spec_factory=filter_spec_factory,
            additional_filter_specs=additional_filter_specs,
        )
        return tuple(
            MetricQueryElement.create(
                metric_spec=input_metric_spec,
                group_by_item_specs=group_by_item_specs_for_inputs,
                predicate_pushdown_state=predicate_pushdown_state_for_inputs,
            )
            for input_metric_spec in input_metric_specs
        )

0 commit comments

Comments
 (0)