|
11 | 11 | from collections import defaultdict |
12 | 12 | from typing import Dict, List, Optional, Set |
13 | 13 |
|
14 | | -import sqlparse |
15 | | -from sqlglot import Expression, parse_one |
| 14 | +import sqlglot |
| 15 | +from sqlglot import Expression, exp, parse_one |
16 | 16 | from sqlglot.expressions import Alias, Case, Identifier, If, Join, Select, Table, Where |
17 | 17 | from sqlglot.optimizer import traverse_scope |
18 | | -from sqlparse.sql import Identifier as SQLParseIdentifier |
19 | | -from sqlparse.sql import TokenList |
20 | 18 |
|
21 | 19 | from preset_cli.api.clients.dbt import ( |
22 | 20 | FilterSchema, |
|
43 | 41 |
|
44 | 42 | def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> str: |
45 | 43 | """ |
46 | | - Return a SQL expression for a given dbt metric. |
| 44 | + Return a SQL expression for a given dbt metric using sqlglot. |
47 | 45 | """ |
48 | 46 | if unique_id not in metrics: |
49 | 47 | raise Exception(f"Invalid metric {unique_id}") |
@@ -77,18 +75,16 @@ def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> s |
77 | 75 | return f"COUNT(DISTINCT {sql})" |
78 | 76 |
|
79 | 77 | if type_ in {"expression", "derived"}: |
80 | | - statement = sqlparse.parse(sql)[0] |
81 | | - tokens = statement.tokens[:] |
82 | | - while tokens: |
83 | | - token = tokens.pop(0) |
84 | | - |
85 | | - if isinstance(token, SQLParseIdentifier) and token.value in metrics: |
86 | | - parent_sql = get_metric_expression(token.value, metrics) |
87 | | - token.tokens = sqlparse.parse(parent_sql)[0].tokens |
88 | | - elif isinstance(token, TokenList): |
89 | | - tokens.extend(token.tokens) |
90 | | - |
91 | | - return str(statement) |
| 78 | + expression = sqlglot.parse_one(sql) |
| 79 | + tokens = expression.find_all(exp.Column) |
| 80 | + |
| 81 | + for token in tokens: |
| 82 | + if token.sql() in metrics: |
| 83 | + parent_sql = get_metric_expression(token.sql(), metrics) |
| 84 | + parent_expression = sqlglot.parse_one(parent_sql) |
| 85 | + token.replace(parent_expression) |
| 86 | + |
| 87 | + return expression.sql() |
92 | 88 |
|
93 | 89 | sorted_metric = dict(sorted(metric.items())) |
94 | 90 | raise Exception(f"Unable to generate metric expression from: {sorted_metric}") |
|
0 commit comments