Skip to content

Commit e36746c

Browse files
Merge branch 'main' into use-sqlglot
2 parents a2f738d + 55321bd commit e36746c

File tree

8 files changed

+974
-195
lines changed

8 files changed

+974
-195
lines changed

src/preset_cli/api/clients/dbt.py

Lines changed: 164 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
_logger = logging.getLogger(__name__)
2525

2626
REST_ENDPOINT = URL("https://cloud.getdbt.com/")
27-
GRAPHQL_ENDPOINT = URL("https://metadata.cloud.getdbt.com/graphql")
27+
METADATA_GRAPHQL_ENDPOINT = URL("https://metadata.cloud.getdbt.com/graphql")
28+
SEMANTIC_LAYER_GRAPHQL_ENDPOINT = URL(
29+
"https://semantic-layer.cloud.getdbt.com/api/graphql",
30+
)
2831

2932

3033
class PostelSchema(Schema):
@@ -472,7 +475,7 @@ class TimeSchema(PostelSchema):
472475

473476
class StringOrSchema(fields.Field):
474477
"""
475-
Dynamic schema constructor for fields that could have a string or another schema
478+
Dynamic schema constructor for fields that could have a string or another schema.
476479
"""
477480

478481
def __init__(self, nested_schema, *args, **kwargs):
@@ -587,6 +590,50 @@ class MetricSchema(PostelSchema):
587590
expression = fields.String()
588591

589592

593+
class MFMetricType(str, Enum):
    """
    Type of the MetricFlow metric.

    The ``str`` mixin makes every member compare equal to its plain string
    value, so members can be built directly from the strings found in API
    payloads, e.g. ``MFMetricType("RATIO")``.
    """

    SIMPLE = "SIMPLE"
    RATIO = "RATIO"
    CUMULATIVE = "CUMULATIVE"
    DERIVED = "DERIVED"
602+
603+
604+
class MFMetricSchema(PostelSchema):
    """
    Schema for a MetricFlow metric.

    Declares the three fields requested by the semantic layer ``metrics``
    query: the metric name, its description and its type.
    """

    name = fields.String()
    description = fields.String()
    # The field is called ``type`` to match the key requested from the API;
    # it only shadows the builtin inside this class body.
    type = PostelEnumField(MFMetricType)
612+
613+
614+
class MFSQLEngine(str, Enum):
    """
    Databases supported by MetricFlow.

    The ``str`` mixin makes every member compare equal to its plain string
    value, so a member can be built directly from the dialect string found in
    an API payload, e.g. ``MFSQLEngine("BIGQUERY")``.
    """

    BIGQUERY = "BIGQUERY"
    DUCKDB = "DUCKDB"
    REDSHIFT = "REDSHIFT"
    POSTGRES = "POSTGRES"
    SNOWFLAKE = "SNOWFLAKE"
    DATABRICKS = "DATABRICKS"
625+
626+
627+
class MFMetricWithSQLSchema(MFMetricSchema):
    """
    MetricFlow metric with dialect and SQL, as well as model.

    Extends ``MFMetricSchema`` with the SQL string for the metric, the SQL
    dialect that string is written in, and a string identifying the model
    the metric belongs to.
    """

    sql = fields.String()
    dialect = PostelEnumField(MFSQLEngine)
    model = fields.String()
635+
636+
590637
class DataResponse(TypedDict):
591638
"""
592639
Type for the GraphQL response.
@@ -602,7 +649,10 @@ class DBTClient: # pylint: disable=too-few-public-methods
602649
"""
603650

604651
def __init__(self, auth: Auth):
605-
self.graphql_client = GraphqlClient(endpoint=GRAPHQL_ENDPOINT)
652+
self.metadata_graphql_client = GraphqlClient(endpoint=METADATA_GRAPHQL_ENDPOINT)
653+
self.semantic_layer_graphql_client = GraphqlClient(
654+
endpoint=SEMANTIC_LAYER_GRAPHQL_ENDPOINT,
655+
)
606656
self.baseurl = REST_ENDPOINT
607657

608658
self.session = auth.session
@@ -611,16 +661,6 @@ def __init__(self, auth: Auth):
611661
self.session.headers["X-Client-Version"] = __version__
612662
self.session.headers["X-dbt-partner-source"] = "preset"
613663

614-
def execute(self, query: str, **variables: Any) -> DataResponse:
615-
"""
616-
Run a GraphQL query.
617-
"""
618-
return self.graphql_client.execute(
619-
query=query,
620-
variables=variables,
621-
headers=self.session.headers,
622-
)
623-
624664
def get_accounts(self) -> List[AccountSchema]:
625665
"""
626666
List all accounts.
@@ -683,37 +723,46 @@ def get_models(self, job_id: int) -> List[ModelSchema]:
683723
Fetch all available models.
684724
"""
685725
query = """
686-
query ($jobId: Int!) {
687-
models(jobId: $jobId) {
688-
uniqueId
689-
dependsOn
690-
childrenL1
691-
name
692-
database
693-
schema
694-
description
695-
meta
696-
tags
697-
columns {
726+
query Models($jobId: BigInt!) {
727+
job(id: $jobId) {
728+
models {
729+
uniqueId
730+
dependsOn
731+
childrenL1
698732
name
733+
database
734+
schema
699735
description
736+
meta
737+
tags
738+
columns {
739+
name
740+
description
741+
type
742+
}
700743
}
701744
}
702745
}
703746
"""
704-
payload = self.execute(query, jobId=job_id)
747+
payload = self.metadata_graphql_client.execute(
748+
query=query,
749+
variables={"jobId": job_id},
750+
headers=self.session.headers,
751+
)
705752

706753
model_schema = ModelSchema()
707-
models = [model_schema.load(model) for model in payload["data"]["models"]]
754+
models = [
755+
model_schema.load(model) for model in payload["data"]["job"]["models"]
756+
]
708757

709758
return models
710759

711-
def get_og_metrics(self, job_id: int) -> List[Any]:
760+
def get_og_metrics(self, job_id: int) -> List[MetricSchema]:
712761
"""
713762
Fetch all available metrics.
714763
"""
715764
query = """
716-
query ($jobId: Int!) {
765+
query GetMetrics($jobId: Int!) {
717766
metrics(jobId: $jobId) {
718767
uniqueId
719768
name
@@ -731,13 +780,98 @@ def get_og_metrics(self, job_id: int) -> List[Any]:
731780
}
732781
}
733782
"""
734-
payload = self.execute(query, jobId=job_id)
783+
payload = self.metadata_graphql_client.execute(
784+
query=query,
785+
variables={"jobId": job_id},
786+
headers=self.session.headers,
787+
)
735788

736789
metric_schema = MetricSchema()
737790
metrics = [metric_schema.load(metric) for metric in payload["data"]["metrics"]]
738791

739792
return metrics
740793

794+
def get_sl_metrics(self, environment_id: int) -> List[MFMetricSchema]:
    """
    Fetch all MetricFlow metrics available in an environment.

    Runs the ``metrics`` query against the semantic layer GraphQL endpoint
    and loads every returned entry (name, description and type) through
    ``MFMetricSchema``.
    """
    query = """
        query GetMetrics($environmentId: BigInt!) {
            metrics(environmentId: $environmentId) {
                name
                description
                type
            }
        }
    """
    payload = self.semantic_layer_graphql_client.execute(
        query=query,
        variables={"environmentId": environment_id},
        headers=self.session.headers,
    )

    metric_schema = MFMetricSchema()
    metrics = [metric_schema.load(metric) for metric in payload["data"]["metrics"]]

    return metrics
817+
818+
def get_sl_metric_sql(self, metric: str, environment_id: int) -> Optional[str]:
    """
    Fetch the compiled SQL for a single MetricFlow metric.

    We fetch one metric at a time because if one metric fails to compile, the entire
    query fails.

    Returns the SQL string on success. When the response carries no data, the
    error messages found in the response (if any) are logged as a warning and
    ``None`` is returned.
    """
    query = """
        mutation CompileSql($environmentId: BigInt!, $metricsInput: [MetricInput!]) {
            compileSql(
                environmentId: $environmentId
                metrics: $metricsInput
                groupBy: []
            ) {
                sql
            }
        }
    """
    payload = self.semantic_layer_graphql_client.execute(
        query=query,
        variables={
            "environmentId": environment_id,
            "metricsInput": [{"name": metric}],
        },
        headers=self.session.headers,
    )

    # A GraphQL error response may set ``data`` to null or omit the entry
    # altogether; ``.get`` treats both the same instead of raising ``KeyError``.
    data = payload.get("data")
    if data is None:
        errors = "\n\n".join(
            error["message"] for error in payload.get("errors", [])
        )
        _logger.warning("Unable to convert metric %s: %s", metric, errors)
        return None

    return data["compileSql"]["sql"]
853+
854+
def get_sl_dialect(self, environment_id: int) -> MFSQLEngine:
    """
    Get the dialect used in the MetricFlow project.

    Runs the ``environmentInfo`` query against the semantic layer GraphQL
    endpoint and converts the reported dialect string into an ``MFSQLEngine``
    member. A dialect string that is not a member value makes the enum
    lookup raise ``ValueError``.
    """
    query = """
        query GetEnvironmentInfo($environmentId: BigInt!) {
            environmentInfo(environmentId: $environmentId) {
                dialect
            }
        }
    """
    payload = self.semantic_layer_graphql_client.execute(
        query=query,
        variables={"environmentId": environment_id},
        headers=self.session.headers,
    )

    return MFSQLEngine(payload["data"]["environmentInfo"]["dialect"])
872+
873+
# def get_sl_metric_sql(self,
874+
741875
def get_database_name(self, job_id: int) -> str:
742876
"""
743877
Return the database name.

src/preset_cli/cli/superset/sync/dbt/command.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,30 @@
66
import sys
77
import warnings
88
from pathlib import Path
9-
from typing import Optional, Tuple
9+
from typing import List, Optional, Tuple
1010

1111
import click
1212
import yaml
1313
from yarl import URL
1414

15-
from preset_cli.api.clients.dbt import DBTClient, JobSchema, MetricSchema, ModelSchema
15+
from preset_cli.api.clients.dbt import (
16+
DBTClient,
17+
JobSchema,
18+
MetricSchema,
19+
MFMetricWithSQLSchema,
20+
ModelSchema,
21+
)
1622
from preset_cli.api.clients.superset import SupersetClient
1723
from preset_cli.auth.token import TokenAuth
1824
from preset_cli.cli.superset.sync.dbt.databases import sync_database
1925
from preset_cli.cli.superset.sync.dbt.datasets import sync_datasets
2026
from preset_cli.cli.superset.sync.dbt.exposures import ModelKey, sync_exposures
2127
from preset_cli.cli.superset.sync.dbt.lib import apply_select
22-
from preset_cli.cli.superset.sync.dbt.metrics import get_superset_metrics_per_model
28+
from preset_cli.cli.superset.sync.dbt.metrics import (
29+
MultipleModelsError,
30+
get_model_from_sql,
31+
get_superset_metrics_per_model,
32+
)
2333
from preset_cli.exceptions import DatabaseNotFoundError
2434

2535

@@ -181,10 +191,7 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many
181191
config["columns"] = list(config["columns"].values())
182192
models.append(model_schema.load(config))
183193
models = apply_select(models, select, exclude)
184-
model_map = {
185-
ModelKey(model["schema"], model["name"]): f"ref('{model['name']}')"
186-
for model in models
187-
}
194+
model_map = {ModelKey(model["schema"], model["name"]): model for model in models}
188195

189196
if exposures_only:
190197
datasets = [
@@ -439,13 +446,37 @@ def dbt_cloud( # pylint: disable=too-many-arguments, too-many-locals
439446

440447
models = dbt_client.get_models(job["id"])
441448
models = apply_select(models, select, exclude)
442-
model_map = {
443-
ModelKey(model["schema"], model["name"]): f"ref('{model['name']}')"
444-
for model in models
445-
}
449+
model_map = {ModelKey(model["schema"], model["name"]): model for model in models}
446450

451+
# original dbt <= 1.6 metrics
447452
og_metrics = dbt_client.get_og_metrics(job["id"])
448-
superset_metrics = get_superset_metrics_per_model(og_metrics)
453+
454+
# MetricFlow metrics
455+
dialect = dbt_client.get_sl_dialect(job["environment_id"])
456+
mf_metric_schema = MFMetricWithSQLSchema()
457+
sl_metrics: List[MFMetricWithSQLSchema] = []
458+
for metric in dbt_client.get_sl_metrics(job["environment_id"]):
459+
sql = dbt_client.get_sl_metric_sql(metric["name"], job["environment_id"])
460+
if sql is not None:
461+
try:
462+
model = get_model_from_sql(sql, dialect, model_map)
463+
except MultipleModelsError:
464+
continue
465+
466+
sl_metrics.append(
467+
mf_metric_schema.load(
468+
{
469+
"name": metric["name"],
470+
"type": metric["type"],
471+
"description": metric["description"],
472+
"sql": sql,
473+
"dialect": dialect.value,
474+
"model": model["unique_id"],
475+
},
476+
),
477+
)
478+
479+
superset_metrics = get_superset_metrics_per_model(og_metrics, sl_metrics)
449480

450481
if exposures_only:
451482
datasets = [

0 commit comments

Comments
 (0)