Skip to content

Commit 947974f

Browse files
Merge pull request #259 from preset-io/get_models_from_sql
chore(dbt): refactor get_model_from_sql
2 parents efa5fd2 + 1631bff commit 947974f

File tree

4 files changed

+40
-38
lines changed

4 files changed

+40
-38
lines changed

src/preset_cli/cli/superset/sync/dbt/command.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,7 @@
2626
from preset_cli.cli.superset.sync.dbt.exposures import ModelKey, sync_exposures
2727
from preset_cli.cli.superset.sync.dbt.lib import apply_select
2828
from preset_cli.cli.superset.sync.dbt.metrics import (
29-
MultipleModelsError,
30-
get_model_from_sql,
29+
get_models_from_sql,
3130
get_superset_metrics_per_model,
3231
)
3332
from preset_cli.exceptions import DatabaseNotFoundError
@@ -351,10 +350,10 @@ def process_sl_metrics(
351350
if sql is None:
352351
continue
353352

354-
try:
355-
model = get_model_from_sql(sql, dialect, model_map)
356-
except MultipleModelsError:
353+
models = get_models_from_sql(sql, dialect, model_map)
354+
if len(models) > 1:
357355
continue
356+
model = models[0]
358357

359358
sl_metrics.append(
360359
mf_metric_schema.load(

src/preset_cli/cli/superset/sync/dbt/metrics.py

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -325,28 +325,19 @@ def convert_metric_flow_to_superset(
325325
}
326326

327327

328-
class MultipleModelsError(Exception):
329-
"""
330-
Raised when a metric depends on multiple models.
331-
"""
332-
333-
334-
def get_model_from_sql(
328+
def get_models_from_sql(
335329
sql: str,
336330
dialect: MFSQLEngine,
337331
model_map: Dict[ModelKey, ModelSchema],
338-
) -> ModelSchema:
332+
) -> List[ModelSchema]:
339333
"""
340334
Return the model associated with a SQL query.
341335
"""
342336
parsed_query = parse_one(sql, dialect=DIALECT_MAP.get(dialect))
343337
sources = list(parsed_query.find_all(Table))
344-
if len(sources) > 1:
345-
raise MultipleModelsError(
346-
f"Unable to convert metrics with multiple sources: {sql}",
347-
)
348338

349-
table = sources[0]
350-
key = ModelKey(table.db, table.name)
339+
for table in sources:
340+
if ModelKey(table.db, table.name) not in model_map:
341+
raise ValueError(f"Unable to find model for SQL source {table}")
351342

352-
return model_map[key]
343+
return [model_map[ModelKey(table.db, table.name)] for table in sources]

tests/cli/superset/sync/dbt/command_test.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,14 @@
149149
"schema": "public",
150150
"unique_id": "model.superset_examples.messages_channels",
151151
},
152+
{
153+
"database": "some_other_table",
154+
"description": "",
155+
"meta": {},
156+
"name": "some_other_table",
157+
"schema": "public",
158+
"unique_id": "model.superset_examples.some_other_table",
159+
},
152160
]
153161

154162
dbt_cloud_metrics = [
@@ -975,7 +983,7 @@ def test_dbt_cloud(mocker: MockerFixture) -> None:
975983
dbt_client.get_sl_metrics.return_value = dbt_metricflow_metrics
976984
dbt_client.get_sl_metric_sql.side_effect = [
977985
"SELECT COUNT(*) FROM public.messages_channels",
978-
"SELECT COUNT(*) FROM public.messages_channels JOIN some_other_table",
986+
"SELECT COUNT(*) FROM public.messages_channels JOIN public.some_other_table",
979987
None,
980988
]
981989
database = mocker.MagicMock()
@@ -1624,6 +1632,10 @@ def test_dbt_cloud_exposures_only(mocker: MockerFixture, fs: FakeFilesystem) ->
16241632
exposures,
16251633
[
16261634
{"schema": "public", "table_name": "messages_channels"},
1635+
{"schema": "public", "table_name": "some_other_table"},
16271636
],
1628-
{("public", "messages_channels"): dbt_cloud_models[0]},
1637+
{
1638+
("public", "messages_channels"): dbt_cloud_models[0],
1639+
("public", "some_other_table"): dbt_cloud_models[1],
1640+
},
16291641
)

tests/cli/superset/sync/dbt/metrics_test.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -12,13 +12,12 @@
1212
from preset_cli.api.clients.dbt import MetricSchema, MFMetricWithSQLSchema, MFSQLEngine
1313
from preset_cli.cli.superset.sync.dbt.exposures import ModelKey
1414
from preset_cli.cli.superset.sync.dbt.metrics import (
15-
MultipleModelsError,
1615
convert_metric_flow_to_superset,
1716
convert_query_to_projection,
1817
get_metric_expression,
1918
get_metric_models,
2019
get_metrics_for_model,
21-
get_model_from_sql,
20+
get_models_from_sql,
2221
get_superset_metrics_per_model,
2322
)
2423

@@ -649,30 +648,31 @@ def test_convert_metric_flow_to_superset(mocker: MockerFixture) -> None:
649648
}
650649

651650

652-
def test_get_model_from_sql() -> None:
651+
def test_get_models_from_sql() -> None:
653652
"""
654-
Test the ``get_model_from_sql`` function.
653+
Test the ``get_models_from_sql`` function.
655654
"""
656655
model_map = {
657656
ModelKey("schema", "table"): {"name": "table"},
657+
ModelKey("schema", "a"): {"name": "a"},
658+
ModelKey("schema", "b"): {"name": "b"},
658659
}
659660

660-
assert get_model_from_sql(
661+
assert get_models_from_sql(
661662
"SELECT 1 FROM project.schema.table",
662663
MFSQLEngine.BIGQUERY,
663664
model_map, # type: ignore
664-
) == {"name": "table"}
665+
) == [{"name": "table"}]
665666

666-
with pytest.raises(MultipleModelsError) as excinfo:
667-
get_model_from_sql(
668-
"SELECT 1 FROM schema.a JOIN schema.b",
669-
MFSQLEngine.BIGQUERY,
670-
{},
671-
)
672-
assert (
673-
str(excinfo.value)
674-
== "Unable to convert metrics with multiple sources: SELECT 1 FROM schema.a JOIN schema.b"
675-
)
667+
assert get_models_from_sql(
668+
"SELECT 1 FROM schema.a JOIN schema.b",
669+
MFSQLEngine.BIGQUERY,
670+
model_map, # type: ignore
671+
) == [{"name": "a"}, {"name": "b"}]
672+
673+
with pytest.raises(ValueError) as excinfo:
674+
get_models_from_sql("SELECT 1 FROM schema.c", MFSQLEngine.BIGQUERY, {})
675+
assert str(excinfo.value) == "Unable to find model for SQL source schema.c"
676676

677677

678678
def test_get_superset_metrics_per_model() -> None:

0 commit comments

Comments
 (0)