2424_logger = logging .getLogger (__name__ )
2525
2626REST_ENDPOINT = URL ("https://cloud.getdbt.com/" )
27- GRAPHQL_ENDPOINT = URL ("https://metadata.cloud.getdbt.com/graphql" )
27+ METADATA_GRAPHQL_ENDPOINT = URL ("https://metadata.cloud.getdbt.com/graphql" )
28+ SEMANTIC_LAYER_GRAPHQL_ENDPOINT = URL (
29+ "https://semantic-layer.cloud.getdbt.com/api/graphql" ,
30+ )
2831
2932
3033class PostelSchema (Schema ):
@@ -472,7 +475,7 @@ class TimeSchema(PostelSchema):
472475
473476class StringOrSchema (fields .Field ):
474477 """
475- Dynamic schema constructor for fields that could have a string or another schema
478+ Dynamic schema constructor for fields that could have a string or another schema.
476479 """
477480
478481 def __init__ (self , nested_schema , * args , ** kwargs ):
@@ -587,6 +590,50 @@ class MetricSchema(PostelSchema):
587590 expression = fields .String ()
588591
589592
593+ class MFMetricType (str , Enum ):
594+ """
595+ Type of the MetricFlow metric.
596+ """
597+
598+ SIMPLE = "SIMPLE"
599+ RATIO = "RATIO"
600+ CUMULATIVE = "CUMULATIVE"
601+ DERIVED = "DERIVED"
602+
603+
604+ class MFMetricSchema (PostelSchema ):
605+ """
606+ Schema for a MetricFlow metric.
607+ """
608+
609+ name = fields .String ()
610+ description = fields .String ()
611+ type = PostelEnumField (MFMetricType )
612+
613+
614+ class MFSQLEngine (str , Enum ):
615+ """
616+ Databases supported by MetricFlow.
617+ """
618+
619+ BIGQUERY = "BIGQUERY"
620+ DUCKDB = "DUCKDB"
621+ REDSHIFT = "REDSHIFT"
622+ POSTGRES = "POSTGRES"
623+ SNOWFLAKE = "SNOWFLAKE"
624+ DATABRICKS = "DATABRICKS"
625+
626+
627+ class MFMetricWithSQLSchema (MFMetricSchema ):
628+ """
629+ MetricFlow metric with dialect and SQL, as well as model.
630+ """
631+
632+ sql = fields .String ()
633+ dialect = PostelEnumField (MFSQLEngine )
634+ model = fields .String ()
635+
636+
590637class DataResponse (TypedDict ):
591638 """
592639 Type for the GraphQL response.
@@ -602,7 +649,10 @@ class DBTClient: # pylint: disable=too-few-public-methods
602649 """
603650
604651 def __init__ (self , auth : Auth ):
605- self .graphql_client = GraphqlClient (endpoint = GRAPHQL_ENDPOINT )
652+ self .metadata_graphql_client = GraphqlClient (endpoint = METADATA_GRAPHQL_ENDPOINT )
653+ self .semantic_layer_graphql_client = GraphqlClient (
654+ endpoint = SEMANTIC_LAYER_GRAPHQL_ENDPOINT ,
655+ )
606656 self .baseurl = REST_ENDPOINT
607657
608658 self .session = auth .session
@@ -611,16 +661,6 @@ def __init__(self, auth: Auth):
611661 self .session .headers ["X-Client-Version" ] = __version__
612662 self .session .headers ["X-dbt-partner-source" ] = "preset"
613663
614- def execute (self , query : str , ** variables : Any ) -> DataResponse :
615- """
616- Run a GraphQL query.
617- """
618- return self .graphql_client .execute (
619- query = query ,
620- variables = variables ,
621- headers = self .session .headers ,
622- )
623-
624664 def get_accounts (self ) -> List [AccountSchema ]:
625665 """
626666 List all accounts.
@@ -683,37 +723,46 @@ def get_models(self, job_id: int) -> List[ModelSchema]:
683723 Fetch all available models.
684724 """
685725 query = """
686- query ($jobId: Int!) {
687- models(jobId: $jobId) {
688- uniqueId
689- dependsOn
690- childrenL1
691- name
692- database
693- schema
694- description
695- meta
696- tags
697- columns {
726+ query Models($jobId: BigInt!) {
727+ job(id: $jobId) {
728+ models {
729+ uniqueId
730+ dependsOn
731+ childrenL1
698732 name
733+ database
734+ schema
699735 description
736+ meta
737+ tags
738+ columns {
739+ name
740+ description
741+ type
742+ }
700743 }
701744 }
702745 }
703746 """
704- payload = self .execute (query , jobId = job_id )
747+ payload = self .metadata_graphql_client .execute (
748+ query = query ,
749+ variables = {"jobId" : job_id },
750+ headers = self .session .headers ,
751+ )
705752
706753 model_schema = ModelSchema ()
707- models = [model_schema .load (model ) for model in payload ["data" ]["models" ]]
754+ models = [
755+ model_schema .load (model ) for model in payload ["data" ]["job" ]["models" ]
756+ ]
708757
709758 return models
710759
711- def get_og_metrics (self , job_id : int ) -> List [Any ]:
760+ def get_og_metrics (self , job_id : int ) -> List [MetricSchema ]:
712761 """
713762 Fetch all available metrics.
714763 """
715764 query = """
716- query ($jobId: Int!) {
765+ query GetMetrics ($jobId: Int!) {
717766 metrics(jobId: $jobId) {
718767 uniqueId
719768 name
@@ -731,13 +780,98 @@ def get_og_metrics(self, job_id: int) -> List[Any]:
731780 }
732781 }
733782 """
734- payload = self .execute (query , jobId = job_id )
783+ payload = self .metadata_graphql_client .execute (
784+ query = query ,
785+ variables = {"jobId" : job_id },
786+ headers = self .session .headers ,
787+ )
735788
736789 metric_schema = MetricSchema ()
737790 metrics = [metric_schema .load (metric ) for metric in payload ["data" ]["metrics" ]]
738791
739792 return metrics
740793
794+ def get_sl_metrics (self , environment_id : int ) -> List [MFMetricSchema ]:
795+ """
796+ Fetch all available metrics.
797+ """
798+ query = """
799+ query GetMetrics($environmentId: BigInt!) {
800+ metrics(environmentId: $environmentId) {
801+ name
802+ description
803+ type
804+ }
805+ }
806+ """
807+ payload = self .semantic_layer_graphql_client .execute (
808+ query = query ,
809+ variables = {"environmentId" : environment_id },
810+ headers = self .session .headers ,
811+ )
812+
813+ metric_schema = MFMetricSchema ()
814+ metrics = [metric_schema .load (metric ) for metric in payload ["data" ]["metrics" ]]
815+
816+ return metrics
817+
818+ def get_sl_metric_sql (self , metric : str , environment_id : int ) -> Optional [str ]:
819+ """
820+ Fetch metric SQL.
821+
822+ We fetch one metric at a time because if one metric fails to compile, the entire
823+ query fails.
824+ """
825+ query = """
826+ mutation CompileSql($environmentId: BigInt!, $metricsInput: [MetricInput!]) {
827+ compileSql(
828+ environmentId: $environmentId
829+ metrics: $metricsInput
830+ groupBy: []
831+ ) {
832+ sql
833+ }
834+ }
835+ """
836+ payload = self .semantic_layer_graphql_client .execute (
837+ query = query ,
838+ variables = {
839+ "environmentId" : environment_id ,
840+ "metricsInput" : [{"name" : metric }],
841+ },
842+ headers = self .session .headers ,
843+ )
844+
845+ if payload ["data" ] is None :
846+ errors = "\n \n " .join (
847+ error ["message" ] for error in payload .get ("errors" , [])
848+ )
849+ _logger .warning ("Unable to convert metric %s: %s" , metric , errors )
850+ return None
851+
852+ return payload ["data" ]["compileSql" ]["sql" ]
853+
854+ def get_sl_dialect (self , environment_id : int ) -> MFSQLEngine :
855+ """
856+ Get the dialect used in the MetricFlow project.
857+ """
858+ query = """
859+ query GetEnvironmentInfo($environmentId: BigInt!) {
860+ environmentInfo(environmentId: $environmentId) {
861+ dialect
862+ }
863+ }
864+ """
865+ payload = self .semantic_layer_graphql_client .execute (
866+ query = query ,
867+ variables = {"environmentId" : environment_id },
868+ headers = self .session .headers ,
869+ )
870+
871+ return MFSQLEngine (payload ["data" ]["environmentInfo" ]["dialect" ])
872+
873+ # def get_sl_metric_sql(self,
874+
741875 def get_database_name (self , job_id : int ) -> str :
742876 """
743877 Return the database name.
0 commit comments