Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20250801-144219.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add path weight-function for the semantic graph
time: 2025-08-01T14:42:19.117781-07:00
custom:
Author: plypaul
Issue: "1804"
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
from __future__ import annotations

import itertools
import logging
from abc import ABC, abstractmethod
from functools import cached_property
Expand All @@ -29,6 +30,7 @@
from metricflow_semantics.experimental.mf_graph.node_descriptor import MetricflowGraphNodeDescriptor
from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet, MutableOrderedSet, OrderedSet
from metricflow_semantics.mf_logging.format_option import PrettyFormatOption
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.mf_logging.pretty_formattable import MetricFlowPrettyFormattable
from metricflow_semantics.mf_logging.pretty_formatter import (
MetricFlowPrettyFormatter,
Expand Down Expand Up @@ -160,10 +162,24 @@ def nodes(self) -> OrderedSet[NodeT_co]:
raise NotImplementedError()

@abstractmethod
def nodes_with_label(self, graph_label: MetricflowGraphLabel) -> OrderedSet[NodeT_co]:
"""Return nodes in the graph with the given label."""
def nodes_with_labels(self, *graph_labels: MetricflowGraphLabel) -> OrderedSet[NodeT_co]:
"""Return nodes in the graph with any one of the given labels."""
raise NotImplementedError()

def node_with_label(self, label: MetricflowGraphLabel) -> NodeT_co:
    """Return the single node in the graph that has the given label.

    Raises:
        KeyError: If the number of nodes with the given label is not exactly one.
    """
    candidates = self.nodes_with_labels(label)
    candidate_count = len(candidates)

    # Happy path first: exactly one match.
    if candidate_count == 1:
        return next(iter(candidates))

    # Only a bounded sample of the matches is included so that the error stays readable.
    raise KeyError(
        LazyFormat(
            "Did not find exactly one node with the given label",
            matching_node_count=candidate_count,
            first_10_nodes=list(itertools.islice(candidates, 10)),
        )
    )

@property
@abstractmethod
def edges(self) -> OrderedSet[EdgeT_co]:
Expand Down Expand Up @@ -221,7 +237,7 @@ def _intersect_edges(self, other: MetricflowGraph[NodeT_co, EdgeT_co]) -> Ordere
return self.edges.intersection(other.edges)

@abstractmethod
def intersection(self, other: MetricflowGraph[NodeT_co, EdgeT_co]) -> Self: # noqa: D102
def intersection(self, other: Self) -> Self: # noqa: D102
raise NotImplementedError()

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from typing_extensions import override

from metricflow_semantics.collection_helpers.syntactic_sugar import mf_flatten
from metricflow_semantics.experimental.mf_graph.graph_id import MetricflowGraphId, SequentialGraphId
from metricflow_semantics.experimental.mf_graph.graph_labeling import MetricflowGraphLabel
from metricflow_semantics.experimental.mf_graph.mf_graph import (
Expand All @@ -15,7 +16,7 @@
MetricflowGraphNode,
NodeT,
)
from metricflow_semantics.experimental.ordered_set import MutableOrderedSet, OrderedSet
from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet, MutableOrderedSet, OrderedSet

logger = logging.getLogger(__name__)

Expand All @@ -42,41 +43,41 @@ class MutableGraph(Generic[NodeT, EdgeT], MetricflowGraph[NodeT, EdgeT], ABC):
_node_to_successor_nodes: DefaultDict[MetricflowGraphNode, MutableOrderedSet[NodeT]]

def add_node(self, node: NodeT) -> None: # noqa: D102
self._nodes.add(node)
for node_property in node.labels:
self._label_to_nodes[node_property].add(node)
self._graph_id = SequentialGraphId.create()
self.add_nodes((node,))

def add_nodes(self, nodes: Iterable[NodeT]) -> None: # noqa: D102
self._nodes.update(nodes)
for node in nodes:
self.add_node(node)
for node_label in node.labels:
self._label_to_nodes[node_label].add(node)
self._graph_id = SequentialGraphId.create()

def add_edge(self, edge: EdgeT) -> None: # noqa: D102
tail_node = edge.tail_node
head_node = edge.head_node
graph_nodes = self._nodes

if tail_node not in graph_nodes:
self.add_node(tail_node)
if head_node not in graph_nodes:
self.add_node(head_node)

self._tail_node_to_edges[tail_node].add(edge)
self._head_node_to_edges[head_node].add(edge)
self._node_to_successor_nodes[tail_node].add(head_node)
self._node_to_predecessor_nodes[head_node].add(tail_node)
self._edges.add(edge)
self._graph_id = SequentialGraphId.create()
self.add_edges((edge,))

def add_edges(self, edges: Iterable[EdgeT]) -> None: # noqa: D102
tail_nodes = [edge.tail_node for edge in edges]
head_nodes = [edge.head_node for edge in edges]

nodes_to_add: MutableOrderedSet[NodeT] = MutableOrderedSet()
nodes_to_add.update(tail_nodes, head_nodes)
nodes_to_add.difference_update(self.nodes)
self.add_nodes(nodes_to_add)

for edge in edges:
self.add_edge(edge)
tail_node = edge.tail_node
head_node = edge.head_node

self._tail_node_to_edges[tail_node].add(edge)
self._head_node_to_edges[head_node].add(edge)
self._node_to_successor_nodes[tail_node].add(head_node)
self._node_to_predecessor_nodes[head_node].add(tail_node)

self._edges.update(edges)
self._graph_id = SequentialGraphId.create()

def update(self, other: MetricflowGraph[NodeT, EdgeT]) -> None:
"""Add the nodes and edges to this graph."""
if len(other.nodes) == 0 and len(other.edges) == 0:
return

self.add_nodes(other.nodes)
self.add_edges(other.edges)
self._graph_id = SequentialGraphId.create()
Expand All @@ -87,8 +88,8 @@ def nodes(self) -> OrderedSet[NodeT]: # noqa: D102
return self._nodes

@override
def nodes_with_label(self, graph_label: MetricflowGraphLabel) -> OrderedSet[NodeT]:
return self._label_to_nodes[graph_label]
def nodes_with_labels(self, *graph_labels: MetricflowGraphLabel) -> OrderedSet[NodeT]:
return FrozenOrderedSet(mf_flatten(self._label_to_nodes[label] for label in graph_labels))

@override
@property
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
from __future__ import annotations

import logging
from functools import cached_property
from typing import Optional, Set

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums import DatePart, TimeGranularity

from metricflow_semantics.collection_helpers.mf_type_aliases import AnyLengthTuple
from metricflow_semantics.experimental.dataclass_helpers import fast_frozen_dataclass
from metricflow_semantics.experimental.metricflow_exception import MetricflowInternalError
from metricflow_semantics.experimental.ordered_set import FrozenOrderedSet, MutableOrderedSet, OrderedSet
from metricflow_semantics.experimental.semantic_graph.attribute_resolution.attribute_recipe_step import (
AttributeRecipeStep,
)
from metricflow_semantics.experimental.semantic_graph.model_id import SemanticModelId
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.model.linkable_element_property import LinkableElementProperty
from metricflow_semantics.model.semantics.linkable_element import LinkableElementType
from metricflow_semantics.time.granularity import ExpandedTimeGranularity

logger = logging.getLogger(__name__)


# The elements of a dunder name kept as a tuple of strings, e.g. ("metric_time", "day").
IndexedDunderName = AnyLengthTuple[str]


@fast_frozen_dataclass()
class AttributeRecipe:
    """The recipe for computing an attribute by following a path in the semantic graph.

    A recipe is assembled from `AttributeRecipeStep` objects via `create`, `append_step`, and `push_step`. Those
    methods return a new `AttributeRecipe` instead of modifying the one they are called on.
    """

    # The elements of the attribute's dunder name in order, e.g. ("metric_time", "day").
    indexed_dunder_name: IndexedDunderName = ()
    # The semantic models added by the recipe steps (via `add_model_join`), in tuple order.
    joined_model_ids: AnyLengthTuple[SemanticModelId] = ()
    # The properties collected from recipe steps. See `resolve_complete_properties` for the complete set.
    element_properties: FrozenOrderedSet[LinkableElementProperty] = FrozenOrderedSet()
    # The entity-link names added by the recipe steps (via `add_entity_link`), in tuple order.
    entity_link_names: AnyLengthTuple[str] = ()

    element_type: Optional[LinkableElementType] = None
    # Maps to the time grain set for a time dimension in a semantic model
    source_time_grain: Optional[TimeGranularity] = None
    # Maps to the attribute's time grain or date part.
    recipe_time_grain: Optional[ExpandedTimeGranularity] = None
    recipe_date_part: Optional[DatePart] = None

    @cached_property
    def dunder_name_elements_set(self) -> Set[str]:
        """The elements of a dunder name as a set for fast repeated-element checks."""
        return set(self.indexed_dunder_name)

    @cached_property
    def joined_model_id_set(self) -> Set[SemanticModelId]:
        """The joined semantic model IDs as a set for fast repeated-model checks."""
        return set(self.joined_model_ids)

    @staticmethod
    def create(initial_step: AttributeRecipeStep) -> AttributeRecipe:
        """Create a recipe that consists of only the given step."""
        dunder_name_elements: AnyLengthTuple[str] = ()
        if initial_step.add_dunder_name_element is not None:
            dunder_name_elements = (initial_step.add_dunder_name_element,)

        entity_link_names: AnyLengthTuple[str] = ()
        if initial_step.add_entity_link is not None:
            entity_link_names = (initial_step.add_entity_link,)

        models_in_join: AnyLengthTuple[SemanticModelId] = ()
        add_model_join = initial_step.add_model_join
        if add_model_join is not None:
            models_in_join = (add_model_join,)

        return AttributeRecipe(
            indexed_dunder_name=dunder_name_elements,
            joined_model_ids=models_in_join,
            element_properties=FrozenOrderedSet(initial_step.add_properties or ()),
            element_type=initial_step.set_element_type,
            entity_link_names=entity_link_names,
            source_time_grain=initial_step.set_source_time_grain,
            recipe_time_grain=initial_step.set_time_grain_access,
            recipe_date_part=initial_step.set_date_part_access,
        )

    @cached_property
    def last_model_id(self) -> Optional[SemanticModelId]:
        """The last model ID in `joined_model_ids`, or `None` if no models have been joined."""
        joined_model_ids = self.joined_model_ids
        # The guard must trigger on the *empty* case. Previously it was inverted, so this returned `None` whenever
        # models were present and raised an `IndexError` when there were none.
        if not joined_model_ids:
            return None

        return joined_model_ids[-1]

    def append_step(self, recipe_step: AttributeRecipeStep) -> AttributeRecipe:
        """Add a step to the end of the recipe.

        Tuple-valued fields get the step's value appended. For single-valued fields, the value set by the step
        takes precedence over the value already in the recipe.
        """
        dundered_name_elements = self.indexed_dunder_name
        if recipe_step.add_dunder_name_element is not None:
            dundered_name_elements = dundered_name_elements + (recipe_step.add_dunder_name_element,)

        entity_link_names = self.entity_link_names
        if recipe_step.add_entity_link is not None:
            entity_link_names = entity_link_names + (recipe_step.add_entity_link,)

        models_in_join = self.joined_model_ids
        join_model = recipe_step.add_model_join
        if join_model is not None:
            models_in_join = models_in_join + (join_model,)

        element_properties = self.element_properties
        if recipe_step.add_properties is not None:
            element_properties = element_properties.union(recipe_step.add_properties)

        return AttributeRecipe(
            indexed_dunder_name=dundered_name_elements,
            joined_model_ids=models_in_join,
            element_properties=element_properties,
            element_type=recipe_step.set_element_type or self.element_type,
            entity_link_names=entity_link_names,
            source_time_grain=recipe_step.set_source_time_grain or self.source_time_grain,
            recipe_time_grain=recipe_step.set_time_grain_access or self.recipe_time_grain,
            recipe_date_part=recipe_step.set_date_part_access or self.recipe_date_part,
        )

    def push_step(self, recipe_step: AttributeRecipeStep) -> AttributeRecipe:
        """Add a step to the beginning of the recipe.

        Tuple-valued fields get the step's value prepended. For single-valued fields, the value already in the
        recipe takes precedence over the value set by the step.
        """
        dundered_name_elements = self.indexed_dunder_name
        if recipe_step.add_dunder_name_element is not None:
            dundered_name_elements = (recipe_step.add_dunder_name_element,) + dundered_name_elements

        entity_link_names = self.entity_link_names
        if recipe_step.add_entity_link is not None:
            entity_link_names = (recipe_step.add_entity_link,) + entity_link_names

        models_in_join = self.joined_model_ids
        add_model_join = recipe_step.add_model_join
        if add_model_join is not None:
            # Prepending to an empty tuple already yields a 1-tuple, so no special case is needed.
            models_in_join = (add_model_join,) + models_in_join

        element_properties = self.element_properties
        if recipe_step.add_properties is not None:
            element_properties = FrozenOrderedSet(recipe_step.add_properties).union(element_properties)

        return AttributeRecipe(
            indexed_dunder_name=dundered_name_elements,
            joined_model_ids=models_in_join,
            element_properties=element_properties,
            element_type=self.element_type or recipe_step.set_element_type,
            entity_link_names=entity_link_names,
            source_time_grain=self.source_time_grain or recipe_step.set_source_time_grain,
            recipe_time_grain=self.recipe_time_grain or recipe_step.set_time_grain_access,
            recipe_date_part=self.recipe_date_part or recipe_step.set_date_part_access,
        )

    def push_steps(self, *updates: AttributeRecipeStep) -> AttributeRecipe:
        """Call `push_step` with each of the given steps in order and return the resulting recipe."""
        result = self
        for update in updates:
            result = result.push_step(update)
        return result

    def resolve_complete_properties(self) -> OrderedSet[LinkableElementProperty]:
        """Resolve the complete set of `LinkableElementProperty` for this recipe.

        While many properties were set by recipe steps during traversal, some need to be resolved at the end as it
        is easier / faster to determine at the end.

        Raises:
            ValueError: If the recipe has no element type, has no joined models without being `METRIC_TIME`, or has
                a source time-grain without a recipe time-grain or date part.
        """
        element_type = self.element_type

        if element_type is None:
            raise ValueError(LazyFormat("Recipe is missing the element type", recipe=self))

        properties = MutableOrderedSet(self.element_properties)

        # The number of joined models determines whether the element is local, joined, or multi-hop.
        model_ids = self.joined_model_ids
        model_id_count = len(model_ids)
        if model_id_count == 0:
            if LinkableElementProperty.METRIC_TIME not in properties:
                raise ValueError(LazyFormat("Recipe is missing context on accessed semantic models", recipe=self))
        elif model_id_count == 1:
            if element_type is not LinkableElementType.METRIC and LinkableElementProperty.METRIC_TIME not in properties:
                properties.add(LinkableElementProperty.LOCAL)
        elif model_id_count == 2:
            properties.add(LinkableElementProperty.JOINED)
        elif model_id_count >= 3:
            properties.update(
                (
                    LinkableElementProperty.JOINED,
                    LinkableElementProperty.MULTI_HOP,
                )
            )
        else:
            raise MetricflowInternalError(
                LazyFormat("Reached unhandled case", model_id_count=model_id_count, recipe=self)
            )

        # Add `DERIVED_TIME_GRANULARITY` if the grain is different from the element's grain.
        source_time_grain = self.source_time_grain
        recipe_time_grain = self.recipe_time_grain
        if source_time_grain is not None:
            if recipe_time_grain is None and self.recipe_date_part is None:
                raise ValueError(
                    LazyFormat(
                        "Recipe has a source time-grain, but no recipe time-grain or recipe date-part", recipe=self
                    )
                )
            if recipe_time_grain is not None and source_time_grain is not recipe_time_grain.base_granularity:
                properties.add(LinkableElementProperty.DERIVED_TIME_GRANULARITY)

        return properties

    def resolve_element_name(self) -> Optional[str]:
        """Resolve the element name.

        Currently, the recipe stores the dunder-name elements (e.g. ["metric_time", "day"]), but not the element name.
        Since the position of the element name in the list depends on the type of element, this method helps to resolve
        that.

        Returns `None` if the recipe is incomplete (no element type or no dunder-name elements).
        """
        element_type = self.element_type

        # Incomplete recipe.
        if element_type is None:
            return None

        dunder_name_elements = self.indexed_dunder_name
        dunder_name_element_count = len(dunder_name_elements)

        # Incomplete recipe.
        if dunder_name_element_count == 0:
            return None

        if element_type is LinkableElementType.TIME_DIMENSION:
            # e.g. ['metric_time']
            if dunder_name_element_count == 1:
                return dunder_name_elements[-1]
            # e.g. ['metric_time', 'day']
            return dunder_name_elements[-2]
        elif (
            element_type is LinkableElementType.ENTITY
            or element_type is LinkableElementType.DIMENSION
            or element_type is LinkableElementType.METRIC
        ):
            return dunder_name_elements[-1]
        else:
            assert_values_exhausted(element_type)
Loading