Skip to content

Commit 719c263

Browse files
authored
Add path weight-function for the semantic graph (#1804)
This PR adds semantic-graph-specific classes to model paths and path weights in the semantic graph. Paths in the semantic graph can be used to figure out the computation (the recipe) for attributes, and the weight function helps to limit traversal so that only valid paths are taken.
1 parent 757dfc6 commit 719c263

File tree

12 files changed

+1082
-39
lines changed

12 files changed

+1082
-39
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
kind: Under the Hood
2+
body: Add path weight-function for the semantic graph
3+
time: 2025-08-01T14:42:19.117781-07:00
4+
custom:
5+
Author: plypaul
6+
Issue: "1804"

metricflow-semantics/metricflow_semantics/experimental/mf_graph/mf_graph.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
from __future__ import annotations
66

7+
import itertools
78
import logging
89
from abc import ABC, abstractmethod
910
from functools import cached_property
@@ -29,6 +30,7 @@
2930
from metricflow_semantics.experimental.mf_graph.node_descriptor import MetricflowGraphNodeDescriptor
3031
from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet, MutableOrderedSet, OrderedSet
3132
from metricflow_semantics.mf_logging.format_option import PrettyFormatOption
33+
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
3234
from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable
3335
from metricflow_semantics.mf_logging.pretty_formatter import (
3436
MetricFlowPrettyFormatter,
@@ -160,10 +162,24 @@ def nodes(self) -> OrderedSet[NodeT_co]:
160162
raise NotImplementedError()
161163

162164
@abstractmethod
163-
def nodes_with_label(self, graph_label: MetricflowGraphLabel) -> OrderedSet[NodeT_co]:
164-
"""Return nodes in the graph with the given label."""
165+
def nodes_with_labels(self, *graph_labels: MetricflowGraphLabel) -> OrderedSet[NodeT_co]:
166+
"""Return nodes in the graph with any one of the given labels."""
165167
raise NotImplementedError()
166168

169+
def node_with_label(self, label: MetricflowGraphLabel) -> NodeT_co:
170+
"""Finds the node with the given label. If not exactly one if found, an error is raised."""
171+
nodes = self.nodes_with_labels(label)
172+
matching_node_count = len(nodes)
173+
if matching_node_count != 1:
174+
raise KeyError(
175+
LazyFormat(
176+
"Did not find exactly one node with the given label",
177+
matching_node_count=matching_node_count,
178+
first_10_nodes=list(itertools.islice(nodes, 10)),
179+
)
180+
)
181+
return next(iter(nodes))
182+
167183
@property
168184
@abstractmethod
169185
def edges(self) -> OrderedSet[EdgeT_co]:
@@ -221,7 +237,7 @@ def _intersect_edges(self, other: MetricflowGraph[NodeT_co, EdgeT_co]) -> Ordere
221237
return self.edges.intersection(other.edges)
222238

223239
@abstractmethod
224-
def intersection(self, other: MetricflowGraph[NodeT_co, EdgeT_co]) -> Self: # noqa: D102
240+
def intersection(self, other: Self) -> Self: # noqa: D102
225241
raise NotImplementedError()
226242

227243
@abstractmethod

metricflow-semantics/metricflow_semantics/experimental/mf_graph/mutable_graph.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from typing_extensions import override
99

10+
from metricflow_semantics.collection_helpers.syntactic_sugar import mf_flatten
1011
from metricflow_semantics.experimental.mf_graph.graph_id import MetricflowGraphId, SequentialGraphId
1112
from metricflow_semantics.experimental.mf_graph.graph_labeling import MetricflowGraphLabel
1213
from metricflow_semantics.experimental.mf_graph.mf_graph import (
@@ -15,7 +16,7 @@
1516
MetricflowGraphNode,
1617
NodeT,
1718
)
18-
from metricflow_semantics.experimental.ordered_set import MutableOrderedSet, OrderedSet
19+
from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet, MutableOrderedSet, OrderedSet
1920

2021
logger = logging.getLogger(__name__)
2122

@@ -42,41 +43,41 @@ class MutableGraph(Generic[NodeT, EdgeT], MetricflowGraph[NodeT, EdgeT], ABC):
4243
_node_to_successor_nodes: DefaultDict[MetricflowGraphNode, MutableOrderedSet[NodeT]]
4344

4445
def add_node(self, node: NodeT) -> None: # noqa: D102
45-
self._nodes.add(node)
46-
for node_property in node.labels:
47-
self._label_to_nodes[node_property].add(node)
48-
self._graph_id = SequentialGraphId.create()
46+
self.add_nodes((node,))
4947

5048
def add_nodes(self, nodes: Iterable[NodeT]) -> None: # noqa: D102
49+
self._nodes.update(nodes)
5150
for node in nodes:
52-
self.add_node(node)
51+
for node_label in node.labels:
52+
self._label_to_nodes[node_label].add(node)
53+
self._graph_id = SequentialGraphId.create()
5354

5455
def add_edge(self, edge: EdgeT) -> None: # noqa: D102
55-
tail_node = edge.tail_node
56-
head_node = edge.head_node
57-
graph_nodes = self._nodes
58-
59-
if tail_node not in graph_nodes:
60-
self.add_node(tail_node)
61-
if head_node not in graph_nodes:
62-
self.add_node(head_node)
63-
64-
self._tail_node_to_edges[tail_node].add(edge)
65-
self._head_node_to_edges[head_node].add(edge)
66-
self._node_to_successor_nodes[tail_node].add(head_node)
67-
self._node_to_predecessor_nodes[head_node].add(tail_node)
68-
self._edges.add(edge)
69-
self._graph_id = SequentialGraphId.create()
56+
self.add_edges((edge,))
7057

7158
def add_edges(self, edges: Iterable[EdgeT]) -> None: # noqa: D102
59+
tail_nodes = [edge.tail_node for edge in edges]
60+
head_nodes = [edge.head_node for edge in edges]
61+
62+
nodes_to_add: MutableOrderedSet[NodeT] = MutableOrderedSet()
63+
nodes_to_add.update(tail_nodes, head_nodes)
64+
nodes_to_add.difference_update(self.nodes)
65+
self.add_nodes(nodes_to_add)
66+
7267
for edge in edges:
73-
self.add_edge(edge)
68+
tail_node = edge.tail_node
69+
head_node = edge.head_node
70+
71+
self._tail_node_to_edges[tail_node].add(edge)
72+
self._head_node_to_edges[head_node].add(edge)
73+
self._node_to_successor_nodes[tail_node].add(head_node)
74+
self._node_to_predecessor_nodes[head_node].add(tail_node)
75+
76+
self._edges.update(edges)
77+
self._graph_id = SequentialGraphId.create()
7478

7579
def update(self, other: MetricflowGraph[NodeT, EdgeT]) -> None:
7680
"""Add the nodes and edges to this graph."""
77-
if len(other.nodes) == 0 and len(other.edges) == 0:
78-
return
79-
8081
self.add_nodes(other.nodes)
8182
self.add_edges(other.edges)
8283
self._graph_id = SequentialGraphId.create()
@@ -87,8 +88,8 @@ def nodes(self) -> OrderedSet[NodeT]: # noqa: D102
8788
return self._nodes
8889

8990
@override
90-
def nodes_with_label(self, graph_label: MetricflowGraphLabel) -> OrderedSet[NodeT]:
91-
return self._label_to_nodes[graph_label]
91+
def nodes_with_labels(self, *graph_labels: MetricflowGraphLabel) -> OrderedSet[NodeT]:
92+
return FrozenOrderedSet(mf_flatten(self._label_to_nodes[label] for label in graph_labels))
9293

9394
@override
9495
@property
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from functools import cached_property
5+
from typing import Optional, Set
6+
7+
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
8+
from dbt_semantic_interfaces.type_enums import DatePart, TimeGranularity
9+
10+
from metricflow_semantics.collection_helpers.mf_type_aliases import AnyLengthTuple
11+
from metricflow_semantics.experimental.dataclass_helpers import fast_frozen_dataclass
12+
from metricflow_semantics.experimental.metricflow_exception import MetricflowInternalError
13+
from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet, MutableOrderedSet, OrderedSet
14+
from metricflow_semantics.experimental.semantic_graph.attribute_resolution.attribute_recipe_step import (
15+
AttributeRecipeStep,
16+
)
17+
from metricflow_semantics.experimental.semantic_graph.model_id import SemanticModelId
18+
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
19+
from metricflow_semantics.model.linkable_element_property import LinkableElementProperty
20+
from metricflow_semantics.model.semantics.linkable_element import LinkableElementType
21+
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
IndexedDunderName = AnyLengthTuple[str]
27+
28+
29+
@fast_frozen_dataclass()
30+
class AttributeRecipe:
31+
"""The recipe for computing an attribute by following a path in the semantic graph."""
32+
33+
indexed_dunder_name: IndexedDunderName = ()
34+
joined_model_ids: AnyLengthTuple[SemanticModelId] = ()
35+
element_properties: FrozenOrderedSet[LinkableElementProperty] = FrozenOrderedSet()
36+
entity_link_names: AnyLengthTuple[str] = ()
37+
38+
element_type: Optional[LinkableElementType] = None
39+
# Maps to the time grain set for a time dimension in a semantic model
40+
source_time_grain: Optional[TimeGranularity] = None
41+
# Maps to the attribute's time grain or date part.
42+
recipe_time_grain: Optional[ExpandedTimeGranularity] = None
43+
recipe_date_part: Optional[DatePart] = None
44+
45+
@cached_property
46+
def dunder_name_elements_set(self) -> Set[str]:
47+
"""The elements of a dunder name as a set for fast repeated-element checks."""
48+
return set(self.indexed_dunder_name)
49+
50+
@cached_property
51+
def joined_model_id_set(self) -> Set[SemanticModelId]:
52+
"""The joined semantic model IDs as a set for fast repeated-model checks."""
53+
return set(self.joined_model_ids)
54+
55+
@staticmethod
56+
def create(initial_step: AttributeRecipeStep) -> AttributeRecipe: # noqa: D102
57+
dunder_name_elements: AnyLengthTuple[str] = ()
58+
if initial_step.add_dunder_name_element is not None:
59+
dunder_name_elements = (initial_step.add_dunder_name_element,)
60+
entity_link_names: AnyLengthTuple[str] = ()
61+
if initial_step.add_entity_link is not None:
62+
entity_link_names = (initial_step.add_entity_link,)
63+
64+
models_in_join: AnyLengthTuple[SemanticModelId] = ()
65+
66+
add_model_join = initial_step.add_model_join
67+
if add_model_join is not None:
68+
models_in_join = models_in_join + (add_model_join,)
69+
70+
return AttributeRecipe(
71+
indexed_dunder_name=dunder_name_elements,
72+
joined_model_ids=models_in_join,
73+
element_properties=FrozenOrderedSet(initial_step.add_properties or ()),
74+
element_type=initial_step.set_element_type,
75+
entity_link_names=entity_link_names,
76+
source_time_grain=initial_step.set_source_time_grain,
77+
recipe_time_grain=initial_step.set_time_grain_access,
78+
recipe_date_part=initial_step.set_date_part_access,
79+
)
80+
81+
@cached_property
82+
def last_model_id(self) -> Optional[SemanticModelId]:
83+
"""The last model ID that was added to the join."""
84+
if self.joined_model_ids:
85+
return None
86+
87+
return tuple(self.joined_model_ids)[-1]
88+
89+
def append_step(self, recipe_step: AttributeRecipeStep) -> AttributeRecipe:
90+
"""Add a step to the end of the recipe."""
91+
dundered_name_elements = self.indexed_dunder_name
92+
if recipe_step.add_dunder_name_element is not None:
93+
dundered_name_elements = dundered_name_elements + (recipe_step.add_dunder_name_element,)
94+
entity_link_names = self.entity_link_names
95+
if recipe_step.add_entity_link is not None:
96+
entity_link_names = entity_link_names + (recipe_step.add_entity_link,)
97+
98+
models_in_join = self.joined_model_ids
99+
join_model = recipe_step.add_model_join
100+
101+
if join_model is not None:
102+
models_in_join = models_in_join + (join_model,)
103+
104+
return AttributeRecipe(
105+
indexed_dunder_name=dundered_name_elements,
106+
joined_model_ids=models_in_join,
107+
element_properties=self.element_properties.union(recipe_step.add_properties)
108+
if recipe_step.add_properties is not None
109+
else self.element_properties,
110+
element_type=recipe_step.set_element_type or self.element_type,
111+
entity_link_names=entity_link_names,
112+
source_time_grain=recipe_step.set_source_time_grain or self.source_time_grain,
113+
recipe_time_grain=recipe_step.set_time_grain_access or self.recipe_time_grain,
114+
recipe_date_part=recipe_step.set_date_part_access or self.recipe_date_part,
115+
)
116+
117+
def push_step(self, recipe_step: AttributeRecipeStep) -> AttributeRecipe:
118+
"""Add a step to the beginning of the recipe."""
119+
dundered_name_elements = self.indexed_dunder_name
120+
if recipe_step.add_dunder_name_element is not None:
121+
dundered_name_elements = (recipe_step.add_dunder_name_element,) + dundered_name_elements
122+
entity_link_names = self.entity_link_names
123+
if recipe_step.add_entity_link is not None:
124+
entity_link_names = (recipe_step.add_entity_link,) + entity_link_names
125+
models_in_join = self.joined_model_ids
126+
add_model_join = recipe_step.add_model_join
127+
if add_model_join is not None:
128+
if len(models_in_join) == 0:
129+
models_in_join = (add_model_join,)
130+
else:
131+
models_in_join = (add_model_join,) + models_in_join
132+
133+
return AttributeRecipe(
134+
indexed_dunder_name=dundered_name_elements,
135+
joined_model_ids=models_in_join,
136+
element_properties=FrozenOrderedSet(recipe_step.add_properties).union(self.element_properties)
137+
if recipe_step.add_properties is not None
138+
else self.element_properties,
139+
element_type=self.element_type or recipe_step.set_element_type,
140+
entity_link_names=entity_link_names,
141+
source_time_grain=self.source_time_grain or recipe_step.set_source_time_grain,
142+
recipe_time_grain=self.recipe_time_grain or recipe_step.set_time_grain_access,
143+
recipe_date_part=self.recipe_date_part or recipe_step.set_date_part_access,
144+
)
145+
146+
def push_steps(self, *updates: AttributeRecipeStep) -> AttributeRecipe:
147+
"""See `push_step`."""
148+
result = self
149+
for update in updates:
150+
result = result.push_step(update)
151+
return result
152+
153+
def resolve_complete_properties(self) -> OrderedSet[LinkableElementProperty]:
154+
"""Resolve the complete set of `LinkableElementProperty` for this recipe.
155+
156+
While many properties were set by recipe steps during traversal, some need to be resolved at the end as it
157+
is easier / faster to determine at the end.
158+
"""
159+
element_type = self.element_type
160+
161+
if element_type is None:
162+
raise ValueError(LazyFormat("Recipe is missing the element type", recipe=self))
163+
164+
properties = MutableOrderedSet(self.element_properties)
165+
166+
model_ids = self.joined_model_ids
167+
model_id_count = len(model_ids)
168+
if model_id_count == 0:
169+
if LinkableElementProperty.METRIC_TIME not in properties:
170+
raise ValueError(LazyFormat("Recipe is missing context on accessed semantic models", recipe=self))
171+
elif model_id_count == 1:
172+
if element_type is not LinkableElementType.METRIC and LinkableElementProperty.METRIC_TIME not in properties:
173+
properties.add(LinkableElementProperty.LOCAL)
174+
elif model_id_count == 2:
175+
properties.add(LinkableElementProperty.JOINED)
176+
elif model_id_count >= 3:
177+
properties.update(
178+
(
179+
LinkableElementProperty.JOINED,
180+
LinkableElementProperty.MULTI_HOP,
181+
)
182+
)
183+
else:
184+
raise MetricflowInternalError(
185+
LazyFormat("Reached unhandled case", model_id_count=model_id_count, recipe=self)
186+
)
187+
188+
# Add `DERIVED_TIME_GRANULARITY` if the grain is different from the element's grain.
189+
source_time_grain = self.source_time_grain
190+
recipe_time_grain = self.recipe_time_grain
191+
if source_time_grain is not None:
192+
if recipe_time_grain is None and self.recipe_date_part is None:
193+
raise ValueError(
194+
LazyFormat(
195+
"Recipe has a source time-grain, but no recipe time-grain or recipe date-part", recipe=self
196+
)
197+
)
198+
if recipe_time_grain is not None and source_time_grain is not recipe_time_grain.base_granularity:
199+
properties.add(LinkableElementProperty.DERIVED_TIME_GRANULARITY)
200+
201+
return properties
202+
203+
def resolve_element_name(self) -> Optional[str]:
204+
"""Resolve the element name.
205+
206+
Currently, the recipe stores the dunder-name elements (e.g. ["metric_time", "day"]), but not the element name.
207+
Since the position of the element name in the list depends on the type of element, this method helps to resolve
208+
that.
209+
"""
210+
element_type = self.element_type
211+
212+
# Incomplete recipe.
213+
if element_type is None:
214+
return None
215+
216+
dunder_name_elements = self.indexed_dunder_name
217+
dunder_name_element_count = len(dunder_name_elements)
218+
219+
# Incomplete recipe.
220+
if dunder_name_element_count == 0:
221+
return None
222+
223+
if element_type is LinkableElementType.TIME_DIMENSION:
224+
# e.g. ['metric_time']
225+
if dunder_name_element_count == 1:
226+
return dunder_name_elements[-1]
227+
# e.g. ['metric_time', 'day']
228+
else:
229+
return dunder_name_elements[-2]
230+
elif (
231+
element_type is LinkableElementType.ENTITY
232+
or element_type is LinkableElementType.DIMENSION
233+
or element_type is LinkableElementType.TIME_DIMENSION
234+
or element_type is LinkableElementType.METRIC
235+
):
236+
return dunder_name_elements[-1]
237+
else:
238+
assert_values_exhausted(element_type)

0 commit comments

Comments
 (0)