Skip to content

Commit 9bdc49c

Browse files
Merge pull request #255 from preset-io/use-sqlglot
chore: use sqlglot
2 parents 55321bd + 8f8bbd0 commit 9bdc49c

File tree

6 files changed

+21
-28
lines changed

6 files changed

+21
-28
lines changed

dev-requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#
2-
# This file is autogenerated by pip-compile with python 3.8
3-
# To update, run:
2+
# This file is autogenerated by pip-compile with Python 3.8
3+
# by the following command:
44
#
55
# pip-compile --no-annotate dev-requirements.in
66
#
@@ -68,7 +68,6 @@ six==1.16.0
6868
soupsieve==2.3.2.post1
6969
sqlalchemy==1.4.40
7070
sqlglot==20.7.1
71-
sqlparse==0.4.3
7271
tabulate==0.8.10
7372
toml==0.10.2
7473
tomli==2.0.1

requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#
2-
# This file is autogenerated by pip-compile with python 3.8
3-
# To update, run:
2+
# This file is autogenerated by pip-compile with Python 3.8
3+
# by the following command:
44
#
55
# pip-compile --no-annotate
66
#
@@ -41,7 +41,6 @@ six==1.16.0
4141
soupsieve==2.3.2.post1
4242
sqlalchemy==1.4.35
4343
sqlglot==20.7.1
44-
sqlparse==0.4.3
4544
tabulate==0.8.9
4645
typing-extensions==4.2.0
4746
urllib3==1.26.9

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ install_requires =
7272
rich>=12.3.0
7373
sqlalchemy>=1.4,<2
7474
sqlglot>=19
75-
sqlparse>=0.4.3
7675
tabulate>=0.8.9
7776
typing-extensions>=4.0.1
7877
yarl>=1.7.2

src/preset_cli/cli/superset/sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
from prompt_toolkit.styles.pygments import style_from_pygments_cls
1616
from pygments.lexers.sql import SqlLexer
1717
from pygments.styles import get_style_by_name
18-
from sqlparse.keywords import KEYWORDS
18+
from sqlglot.tokens import Tokenizer
1919
from tabulate import tabulate
2020
from yarl import URL
2121

2222
from preset_cli.api.clients.superset import SupersetClient
2323
from preset_cli.exceptions import SupersetError
2424

25-
sql_completer = WordCompleter(list(KEYWORDS))
25+
sql_completer = WordCompleter(list(Tokenizer.KEYWORDS))
2626
style = style_from_pygments_cls(get_style_by_name("stata-dark"))
2727

2828

src/preset_cli/cli/superset/sync/dbt/metrics.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
from collections import defaultdict
1212
from typing import Dict, List, Optional, Set
1313

14-
import sqlparse
15-
from sqlglot import Expression, parse_one
14+
import sqlglot
15+
from sqlglot import Expression, exp, parse_one
1616
from sqlglot.expressions import Alias, Case, Identifier, If, Join, Select, Table, Where
1717
from sqlglot.optimizer import traverse_scope
18-
from sqlparse.sql import Identifier as SQLParseIdentifier
19-
from sqlparse.sql import TokenList
2018

2119
from preset_cli.api.clients.dbt import (
2220
FilterSchema,
@@ -43,7 +41,7 @@
4341

4442
def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> str:
4543
"""
46-
Return a SQL expression for a given dbt metric.
44+
Return a SQL expression for a given dbt metric using sqlglot.
4745
"""
4846
if unique_id not in metrics:
4947
raise Exception(f"Invalid metric {unique_id}")
@@ -77,18 +75,16 @@ def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> s
7775
return f"COUNT(DISTINCT {sql})"
7876

7977
if type_ in {"expression", "derived"}:
80-
statement = sqlparse.parse(sql)[0]
81-
tokens = statement.tokens[:]
82-
while tokens:
83-
token = tokens.pop(0)
84-
85-
if isinstance(token, SQLParseIdentifier) and token.value in metrics:
86-
parent_sql = get_metric_expression(token.value, metrics)
87-
token.tokens = sqlparse.parse(parent_sql)[0].tokens
88-
elif isinstance(token, TokenList):
89-
tokens.extend(token.tokens)
90-
91-
return str(statement)
78+
expression = sqlglot.parse_one(sql)
79+
tokens = expression.find_all(exp.Column)
80+
81+
for token in tokens:
82+
if token.sql() in metrics:
83+
parent_sql = get_metric_expression(token.sql(), metrics)
84+
parent_expression = sqlglot.parse_one(parent_sql)
85+
token.replace(parent_expression)
86+
87+
return expression.sql()
9288

9389
sorted_metric = dict(sorted(metric.items()))
9490
raise Exception(f"Unable to generate metric expression from: {sorted_metric}")

tests/cli/superset/sync/dbt/metrics_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def test_get_metric_expression() -> None:
8484
assert get_metric_expression("two", metrics) == "COUNT(DISTINCT user_id)"
8585

8686
assert get_metric_expression("three", metrics) == (
87-
"COUNT(CASE WHEN is_paying is true AND lifetime_value >= 100 AND "
88-
"company_name != 'Acme, Inc' AND signup_date >= '2020-01-01' THEN user_id END) "
87+
"COUNT(CASE WHEN is_paying IS TRUE AND lifetime_value >= 100 AND "
88+
"company_name <> 'Acme, Inc' AND signup_date >= '2020-01-01' THEN user_id END) "
8989
"- COUNT(DISTINCT user_id)"
9090
)
9191

0 commit comments

Comments
 (0)