Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20260107-134853.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Support explain-SQL tests using external manifests
time: 2026-01-07T13:48:53.725182-08:00
custom:
Author: plypaul
Issue: "1952"
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from __future__ import annotations

from pathlib import Path
from typing import Sequence

from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest
from dbt_semantic_interfaces.transformations.pydantic_rule_set import PydanticSemanticManifestTransformRuleSet
from dbt_semantic_interfaces.transformations.semantic_manifest_transformer import PydanticSemanticManifestTransformer
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.test_helpers.manifest_helpers import mf_load_manifest_from_json_file
from metricflow_semantics.toolkit.mf_type_aliases import Pair
from typing_extensions import override

from tests_metricflow.release_validation.manifest_setup.manifest_setup import ManifestSetup, ManifestSetupSource
from tests_metricflow.release_validation.manifest_transforms.modify_time_spine import ModifyTimeSpineTableRule
from tests_metricflow.release_validation.manifest_transforms.normalize_sql import NormalizeSqlRule
from tests_metricflow.table_snapshot.table_snapshots import (
SqlTableColumnDefinition,
SqlTableColumnType,
SqlTableSnapshot,
)


class ExternalManifestSetupSource(ManifestSetupSource):
    """Provides setups from a directory containing external manifests (JSON-serialized).

    Example use case: "Here are a bunch of JSON-serialized manifests from customers. Check to see that a new release
    of MF doesn't break any of their saved queries."

    Since manifests from external sources may be authored using engine-specific SQL and rely on the existence of
    specific tables, several manifest transformations are required so that the manifest can be used with
    `DuckDbExplainTester`.

    In general, references to specific SQL tables and user-defined SQL must be replaced with similar dummy values.
    Since dummy values are used, the tests using the modified manifest won't be completely faithful, but can still
    capture some potential errors.

    Please see the associated transform rules for details.
    """

    def __init__(self, manifest_directory: Path) -> None:
        """Initializer.

        Args:
            manifest_directory: Directory containing `*.json` files that represent serialized manifests.
        """
        self._manifest_directory = manifest_directory
        # Name of the single-column dummy table that replaces source tables referenced by the manifest.
        self._dummy_table_name = "dummy_table"
        # Name of the table that backs the manifest's time-spine configuration(s).
        self._time_spine_table_name = "time_spine"

    @override
    def get_manifest_setups(self) -> Sequence[ManifestSetup]:
        """Return one `ManifestSetup` per JSON manifest found under the configured directory.

        Each setup pairs the transformed manifest with the two table snapshots (dummy table + time spine)
        that the transformed manifest's SQL references.
        """
        setups = []
        for manifest_name, semantic_manifest in self._find_manifests():
            # Each manifest gets its own schema (named after the manifest) so tables for different
            # manifests don't collide.
            schema_name = manifest_name
            dummy_table = SqlTable(schema_name=schema_name, table_name=self._dummy_table_name)
            # A single INT column with a single row is sufficient since `NormalizeSqlRule` rewrites all
            # element expressions to constants.
            dummy_table_snapshot = SqlTableSnapshot(
                table_name=dummy_table.table_name,
                schema_name=dummy_table.schema_name,
                column_definitions=(SqlTableColumnDefinition(name="int_column", type=SqlTableColumnType.INT),),
                rows=(("1",),),
                file_path=None,
            )
            time_spine_table = SqlTable(schema_name=schema_name, table_name=self._time_spine_table_name)
            time_spine_table_column_names = self._get_time_spine_column_names(semantic_manifest)
            # One TIME column per column name referenced by the manifest's time-spine configuration,
            # populated with a single placeholder date row.
            time_spine_table_snapshot = SqlTableSnapshot(
                table_name=time_spine_table.table_name,
                schema_name=schema_name,
                column_definitions=tuple(
                    SqlTableColumnDefinition(name=column_name, type=SqlTableColumnType.TIME)
                    for column_name in time_spine_table_column_names
                ),
                rows=(tuple("2020-01-01" for _ in time_spine_table_column_names),),
                file_path=None,
            )
            setups.append(
                ManifestSetup(
                    manifest_name=manifest_name,
                    semantic_manifest=semantic_manifest,
                    table_snapshots=(dummy_table_snapshot, time_spine_table_snapshot),
                )
            )
        return setups

    @staticmethod
    def _get_time_spine_column_names(semantic_manifest: PydanticSemanticManifest) -> Sequence[str]:
        """Return the sorted, de-duplicated set of column names used by the manifest's time-spine configs.

        Covers both the legacy `time_spine_table_configurations` and the newer `time_spines` (including
        custom-granularity columns, which fall back to the granularity name when no column name is set).
        """
        column_names = set()
        for time_spine_table_configuration in semantic_manifest.project_configuration.time_spine_table_configurations:
            column_names.add(time_spine_table_configuration.column_name)

        for time_spine in semantic_manifest.project_configuration.time_spines:
            column_names.add(time_spine.primary_column.name)
            for custom_grain in time_spine.custom_granularities:
                column_names.add(custom_grain.column_name or custom_grain.name)

        return sorted(column_names)

    def _load_manifest(self, manifest_name: str, manifest_path: Path) -> PydanticSemanticManifest:
        """Load the manifest at `manifest_path` and apply the transformations needed for explain tests.

        The SQL-normalization rules run before the standard primary rules so that later rules see the
        already-normalized SQL.
        """
        semantic_manifest = mf_load_manifest_from_json_file(manifest_path)
        rule_set = PydanticSemanticManifestTransformRuleSet().all_rules

        schema_name = manifest_name
        primary_rules = (
            NormalizeSqlRule(SqlTable(schema_name=schema_name, table_name=self._dummy_table_name)),
            ModifyTimeSpineTableRule(schema_name=schema_name, table_name=self._time_spine_table_name),
        ) + tuple(rule_set[0])
        secondary_rules = tuple(rule_set[1])

        return PydanticSemanticManifestTransformer.transform(
            semantic_manifest,
            ordered_rule_sequences=(
                primary_rules,
                secondary_rules,
            ),
        )

    def _find_manifests(self) -> Sequence[Pair[str, PydanticSemanticManifest]]:
        """Return pairs that group the manifest name and the associated manifest."""
        name_and_manifest_pairs = []
        for manifest_path in self._manifest_directory.rglob("*.json"):
            # Skip empty placeholder files (e.g. `.gitkeep`-style artifacts).
            if manifest_path.stat().st_size == 0:
                continue

            # Use `stem` to drop only the final `.json` suffix; `.name.replace(".json", "")` would also
            # strip ".json" occurring in the middle of a filename.
            manifest_name = manifest_path.stem
            # NOTE(review): `rglob` can find same-named files in different subdirectories, which would
            # produce duplicate manifest names (and thus colliding schema names) — confirm whether the
            # input directory guarantees unique filenames.
            name_and_manifest_pairs.append(
                (
                    manifest_name,
                    self._load_manifest(manifest_name, manifest_path),
                )
            )
        return name_and_manifest_pairs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

import copy
import logging
import re
from collections.abc import Sequence
from typing import Optional

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection
from dbt_semantic_interfaces.implementations.node_relation import PydanticNodeRelation
from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest
from dbt_semantic_interfaces.transformations.transform_rule import SemanticManifestTransformRule
from dbt_semantic_interfaces.type_enums import DimensionType, MetricType
from metricflow_semantics.sql.sql_table import SqlTable
from typing_extensions import override

logger = logging.getLogger(__name__)


class NormalizeSqlRule(SemanticManifestTransformRule[PydanticSemanticManifest]):
    """Replace user-specified SQL in the manifest with SQL that works with DuckDB.

    This rule replaces user-specified SQL in fields such as:

    * Node relations for semantic models.
    * Element expressions.
    * Filters.
    * Derived metric expressions.

    with dummy SQL that is compatible with DuckDB. This enables validation (to a degree) using a local DuckDB instance.
    """

    def __init__(self, dummy_table: SqlTable) -> None:  # noqa: D107
        # All semantic models are re-pointed at this one table; a deep copy of this relation is assigned
        # per model in `transform_model`.
        self._dummy_table_node_relation = PydanticNodeRelation.from_string(dummy_table.sql)

    @staticmethod
    def _extract_variable_expression(jinja_template_str: str) -> Sequence[str]:
        """Return the variable expressions in the Jinja template.

        For example:
            `{{ user.first_name }} AND {{ user.last_name }}` -> ["{{ user.first_name }}", "{{ user.last_name }}"]
        """
        # Pattern to match {{ ... }} expressions, capturing everything between {{ and }}
        pattern = r"\{\{[^}]*\}\}"
        matches = re.findall(pattern, jinja_template_str)
        return matches

    def _update_filter(self, filter_intersection: Optional[PydanticWhereFilterIntersection]) -> None:
        """Replace the user filter SQL with a similar one that can run on DuckDB.

        The user filter SQL is replaced by an expression that checks if the items specified in the Jinja template are
        null.

        e.g.

            {{ Dimension('listing__country_latest') }} = 'us'

            ->

            {{ Dimension('listing__country_latest') }} IS NOT NULL

        It's necessary to retain the Jinja template to better reproduce how the engine would have generated SQL
        before this transformation.
        """
        if filter_intersection is None:
            return

        for where_filter in filter_intersection.where_filters:
            variable_expressions = NormalizeSqlRule._extract_variable_expression(where_filter.where_sql_template)

            # Could be the case if the filter only contains a custom SQL expression.
            if len(variable_expressions) == 0:
                where_filter.where_sql_template = "TRUE"
                continue

            rewritten_expressions = tuple(f"({expression} IS NOT NULL)" for expression in variable_expressions)
            where_filter.where_sql_template = " AND ".join(rewritten_expressions)

    @override
    def transform_model(self, semantic_manifest: PydanticSemanticManifest) -> PydanticSemanticManifest:
        """Return a deep copy of the manifest with all user SQL replaced by DuckDB-safe dummy SQL."""
        transformed_manifest = copy.deepcopy(semantic_manifest)

        # Replace all element expressions in semantic models.
        for semantic_model in transformed_manifest.semantic_models:
            for measure in semantic_model.measures:
                measure.expr = "1"
            for entity in semantic_model.entities:
                entity.expr = "1"
            for dimension in semantic_model.dimensions:
                dimension_type = dimension.type
                if dimension_type is DimensionType.CATEGORICAL:
                    dimension.expr = "'1'"
                elif dimension_type is DimensionType.TIME:
                    dimension.expr = "CAST('2020-01-01' AS TIMESTAMP)"
                else:
                    assert_values_exhausted(dimension_type)

            # Deep copy so that mutating one model's node relation can't affect another's.
            semantic_model.node_relation = copy.deepcopy(self._dummy_table_node_relation)

        # Replace all metric filters and derived metric expressions.
        for metric in transformed_manifest.metrics:
            self._update_filter(metric.filter)

            for input_metric in metric.type_params.metrics or ():
                self._update_filter(input_metric.filter)

            metric_type = metric.type
            if metric_type is MetricType.SIMPLE:
                metric.type_params.expr = "1"
            elif metric_type is MetricType.DERIVED:
                # Replace the user-defined derived expression with a sum of the referenced input metrics,
                # preserving the aliases that the expression would reference.
                referenced_metric_names = [
                    (input_metric.alias or input_metric.name) for input_metric in metric.type_params.metrics or ()
                ]
                metric.type_params.expr = " + ".join(referenced_metric_names)
            elif metric_type is MetricType.RATIO:
                numerator = metric.type_params.numerator
                if numerator is not None:
                    self._update_filter(numerator.filter)
                denominator = metric.type_params.denominator
                if denominator is not None:
                    self._update_filter(denominator.filter)
            elif metric_type is MetricType.CONVERSION or metric_type is MetricType.CUMULATIVE:
                # No type-specific SQL to replace; input-metric filters were handled above.
                pass
            else:
                assert_values_exhausted(metric_type)

        # Replace all filters in saved queries.
        for saved_query in transformed_manifest.saved_queries:
            self._update_filter(saved_query.query_params.where)

        return transformed_manifest
52 changes: 52 additions & 0 deletions tests_metricflow/release_validation/test_external_manifest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

import logging
from pathlib import Path

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.toolkit.mf_logging.lazy_formattable import LazyFormat

from metricflow.protocols.sql_client import SqlClient
from tests_metricflow.fixtures.sql_clients.ddl_sql_client import SqlClientWithDDLMethods
from tests_metricflow.release_validation.explain_results_snapshot import assert_explain_tester_results_equal
from tests_metricflow.release_validation.explain_runner import ExplainQueryStatus
from tests_metricflow.release_validation.explain_tester import DuckDbExplainTester
from tests_metricflow.release_validation.manifest_setup.external_manifest import ExternalManifestSetupSource
from tests_metricflow.release_validation.request_generation.saved_query import SavedQueryRequestGenerator
from tests_metricflow.table_snapshot.table_snapshots import SqlTableSnapshotRepository

logger = logging.getLogger(__name__)


@pytest.mark.slow
@pytest.mark.duckdb_only
def test_explain_all_saved_queries_from_external_manifest(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    ddl_sql_client: SqlClientWithDDLMethods,
    sql_client: SqlClient,
    simple_semantic_manifest: PydanticSemanticManifest,
    source_table_snapshot_repository: SqlTableSnapshotRepository,
    tmp_path: Path,
) -> None:
    """Test generated SQL for all saved queries in a JSON-serialized manifest.

    Round-trips the simple semantic manifest through JSON serialization, loads it back via
    `ExternalManifestSetupSource`, and checks that explain succeeds for every saved query and that the
    results match the snapshot.
    """
    # NOTE(review): `ddl_sql_client`, `sql_client`, and `source_table_snapshot_repository` are not referenced
    # in the body — presumably requested for fixture side effects; confirm before removing.
    manifest_json_directory = tmp_path.joinpath("manifest_json")
    manifest_path = DuckDbExplainTester.serialize_manifest_to_json_file(
        "simple_semantic_manifest", simple_semantic_manifest, manifest_json_directory
    )
    logger.debug(LazyFormat("Wrote manifest to JSON file", manifest_path=manifest_path))

    explain_tester = DuckDbExplainTester(
        manifest_setup_source=ExternalManifestSetupSource(manifest_json_directory),
        result_file_directory=tmp_path.joinpath("results"),
        request_generator=SavedQueryRequestGenerator(),
    )
    results = explain_tester.run()

    # Every saved query must explain successfully, and the result set must match the recorded snapshot.
    assert all(result.status is ExplainQueryStatus.PASS for result in results)
    assert_explain_tester_results_equal(request=request, mf_test_configuration=mf_test_configuration, results=results)
Loading