94 changes: 3 additions & 91 deletions sqlglot/dialects/databricks.py
@@ -3,26 +3,14 @@
from copy import deepcopy
from collections import defaultdict

from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
date_delta_sql,
timestamptrunc_sql,
groupconcat_sql,
)
from sqlglot import exp
from sqlglot.dialects.spark import Spark
from sqlglot.generators.databricks import DatabricksGenerator
from sqlglot.parsers.databricks import DatabricksParser
from sqlglot.tokens import TokenType
from sqlglot.optimizer.annotate_types import TypeAnnotator


def _jsonextract_sql(
self: Databricks.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
this = self.sql(expression, "this")
expr = self.sql(expression, "expression")
return f"{this}:{expr}"


class Databricks(Spark):
SAFE_DIVISION = False
COPY_PARAMS_ARE_CSV = False
@@ -49,80 +37,4 @@ class Tokenizer(Spark.Tokenizer):

Parser = DatabricksParser

class Generator(Spark.Generator):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
COPY_PARAMS_ARE_WRAPPED = False
COPY_PARAMS_EQ_REQUIRED = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = False
SAFE_JSON_PATH_KEY_RE = exp.SAFE_IDENTIFIER_RE
QUOTE_JSON_PATH = False
PARSE_JSON_NAME = "PARSE_JSON"

TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
exp.CurrentVersion: lambda *_: "CURRENT_VERSION()",
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DatetimeAdd: lambda self, e: self.func(
"TIMESTAMPADD", e.unit, e.expression, e.this
),
exp.DatetimeSub: lambda self, e: self.func(
"TIMESTAMPADD",
e.unit,
exp.Mul(this=e.expression, expression=exp.Literal.number(-1)),
e.this,
),
exp.DatetimeTrunc: timestamptrunc_sql(),
exp.GroupConcat: groupconcat_sql,
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
transforms.any_to_exists,
]
),
exp.JSONExtract: _jsonextract_sql,
exp.JSONExtractScalar: _jsonextract_sql,
exp.JSONPathRoot: lambda *_: "",
exp.ToChar: lambda self, e: (
self.cast_sql(exp.Cast(this=e.this, to=exp.DataType(this="STRING")))
if e.args.get("is_numeric")
else self.function_fallback_sql(e)
),
exp.CurrentCatalog: lambda *_: "CURRENT_CATALOG()",
}

TRANSFORMS.pop(exp.RegexpLike)
TRANSFORMS.pop(exp.TryCast)

TYPE_MAPPING = {
**Spark.Generator.TYPE_MAPPING,
exp.DType.NULL: "VOID",
}

def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
kind = expression.kind
if (
constraint
and isinstance(kind, exp.DataType)
and kind.this in exp.DataType.INTEGER_TYPES
):
# only BIGINT generated identity constraints are supported
expression.set("kind", exp.DataType.build("bigint"))

return super().columndef_sql(expression, sep)

def jsonpath_sql(self, expression: exp.JSONPath) -> str:
expression.set("escape", None)
return super().jsonpath_sql(expression)

def uniform_sql(self, expression: exp.Uniform) -> str:
gen = expression.args.get("gen")
seed = expression.args.get("seed")

            # Snowflake's UNIFORM(min, max, gen) takes gen as RANDOM(), RANDOM(seed), or a constant -> extract the seed
if gen:
seed = gen.this

return self.func("UNIFORM", expression.this, expression.expression, seed)
Generator = DatabricksGenerator
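
The extracted module itself isn't part of this diff. For reviewers' context, here is a minimal sketch of what sqlglot/generators/databricks.py presumably contains after the move — the helper and class body are assumed to be the code deleted above, rebased from Spark.Generator onto SparkGenerator (names taken from the new imports; the trimmed-down body is illustrative, not the actual file):

from __future__ import annotations

from sqlglot import exp
from sqlglot.generators.spark import SparkGenerator


def _jsonextract_sql(
    self: DatabricksGenerator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
    # Databricks spells JSON extraction with the colon operator: col:path
    this = self.sql(expression, "this")
    expr = self.sql(expression, "expression")
    return f"{this}:{expr}"


class DatabricksGenerator(SparkGenerator):
    TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
    COPY_PARAMS_ARE_WRAPPED = False
    COPY_PARAMS_EQ_REQUIRED = True
    PARSE_JSON_NAME = "PARSE_JSON"

    TRANSFORMS = {
        **SparkGenerator.TRANSFORMS,
        exp.JSONExtract: _jsonextract_sql,
        exp.JSONExtractScalar: _jsonextract_sql,
        # ... remaining transforms, type mappings and methods carried over
        # unchanged from the old inline Databricks.Generator
    }
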
198 changes: 4 additions & 194 deletions sqlglot/dialects/spark.py
@@ -2,71 +2,11 @@

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import (
array_append_sql,
rename_func,
unit_to_var,
timestampdiff_sql,
date_delta_to_binary_interval_op,
groupconcat_sql,
)
from sqlglot.dialects.spark2 import Spark2
from sqlglot.generators.spark import SparkGenerator
from sqlglot.parsers.spark import SparkParser
from sqlglot import generator
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.typing.spark import EXPRESSION_METADATA
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
from sqlglot.transforms import (
ctas_with_tmp_tables_to_create_tmp_view,
remove_unique_constraints,
preprocess,
move_partitioned_by_to_schema_columns,
)


def _normalize_partition(e: exp.Expr) -> exp.Expr:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
return exp.to_identifier(e)
if isinstance(e, exp.Literal):
return exp.to_identifier(e.name)
return e


def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
if not expression.unit or (
isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
):
# Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
return self.func("DATE_ADD", expression.this, expression.expression)

this = self.func(
"DATE_ADD",
unit_to_var(expression),
expression.expression,
expression.this,
)

if isinstance(expression, exp.TsOrDsAdd):
# The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
# in other dialects
return_type = expression.return_type
if not return_type.is_type(exp.DType.TIMESTAMP, exp.DType.DATETIME):
this = f"CAST({this} AS {return_type})"

return this


def _groupconcat_sql(self: Spark.Generator, expression: exp.GroupConcat) -> str:
if self.dialect.version < (4,):
expr = exp.ArrayToString(
this=exp.ArrayAgg(this=expression.this),
expression=expression.args.get("separator") or exp.Literal.string(""),
)
return self.sql(expr)

return groupconcat_sql(self, expression)
from sqlglot.typing.spark import EXPRESSION_METADATA


class Spark(Spark2):
@@ -91,134 +31,4 @@ class Tokenizer(Spark2.Tokenizer):

Parser = SparkParser

class Generator(Spark2.Generator):
SUPPORTS_TO_NUMBER = True
PAD_FILL_PATTERN_IS_REQUIRED = False
SUPPORTS_CONVERT_TIMEZONE = True
SUPPORTS_MEDIAN = True
SUPPORTS_UNIX_SECONDS = True
SUPPORTS_DECODE_CASE = True
SET_ASSIGNMENT_REQUIRES_VARIABLE_KEYWORD = True

TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DType.MONEY: "DECIMAL(15, 4)",
exp.DType.SMALLMONEY: "DECIMAL(6, 4)",
exp.DType.UUID: "STRING",
exp.DType.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
exp.DType.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
}

TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
exp.ArrayConstructCompact: lambda self, e: self.func(
"ARRAY_COMPACT", self.func("ARRAY", *e.expressions)
),
exp.ArrayInsert: lambda self, e: self.func(
"ARRAY_INSERT", e.this, e.args.get("position"), e.expression
),
exp.ArrayAppend: array_append_sql("ARRAY_APPEND"),
exp.ArrayPrepend: array_append_sql("ARRAY_PREPEND"),
exp.BitwiseAndAgg: rename_func("BIT_AND"),
exp.BitwiseOrAgg: rename_func("BIT_OR"),
exp.BitwiseXorAgg: rename_func("BIT_XOR"),
exp.BitwiseCount: rename_func("BIT_COUNT"),
exp.Create: preprocess(
[
remove_unique_constraints,
lambda e: ctas_with_tmp_tables_to_create_tmp_view(
e, temporary_storage_provider
),
move_partitioned_by_to_schema_columns,
]
),
exp.CurrentVersion: rename_func("VERSION"),
exp.DateFromUnixDate: rename_func("DATE_FROM_UNIX_DATE"),
exp.DatetimeAdd: date_delta_to_binary_interval_op(cast=False),
exp.DatetimeSub: date_delta_to_binary_interval_op(cast=False),
exp.GroupConcat: _groupconcat_sql,
exp.EndsWith: rename_func("ENDSWITH"),
exp.JSONKeys: rename_func("JSON_OBJECT_KEYS"),
exp.PartitionedByProperty: lambda self, e: (
f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}"
),
exp.SafeAdd: rename_func("TRY_ADD"),
exp.SafeMultiply: rename_func("TRY_MULTIPLY"),
exp.SafeSubtract: rename_func("TRY_SUBTRACT"),
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimeAdd: date_delta_to_binary_interval_op(cast=False),
exp.TimeSub: date_delta_to_binary_interval_op(cast=False),
exp.TsOrDsAdd: _dateadd_sql,
exp.TimestampAdd: _dateadd_sql,
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampSub: date_delta_to_binary_interval_op(cast=False),
exp.DatetimeDiff: timestampdiff_sql,
exp.TimestampDiff: timestampdiff_sql,
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
),
}
TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.With)

def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
return generator.Generator.ignorenulls_sql(self, expression)

def bracket_sql(self, expression: exp.Bracket) -> str:
if expression.args.get("safe"):
key = seq_get(self.bracket_offset_expressions(expression, index_offset=1), 0)
return self.func("TRY_ELEMENT_AT", expression.this, key)

return super().bracket_sql(expression)

def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"

def anyvalue_sql(self, expression: exp.AnyValue) -> str:
return self.function_fallback_sql(expression)

def datediff_sql(self, expression: exp.DateDiff) -> str:
end = self.sql(expression, "this")
start = self.sql(expression, "expression")

if expression.unit:
return self.func("DATEDIFF", unit_to_var(expression), start, end)

return self.func("DATEDIFF", end, start)

def placeholder_sql(self, expression: exp.Placeholder) -> str:
if not expression.args.get("widget"):
return super().placeholder_sql(expression)

return f"{{{expression.name}}}"

def readparquet_sql(self, expression: exp.ReadParquet) -> str:
if len(expression.expressions) != 1:
self.unsupported("READ_PARQUET with multiple arguments is not supported")
return ""

parquet_file = expression.expressions[0]
return f"parquet.`{parquet_file.name}`"

def ifblock_sql(self, expression: exp.IfBlock) -> str:
condition = expression.this
true_block = expression.args.get("true")

condition_expr = None
if isinstance(condition, exp.Not):
inner = condition.this
if isinstance(inner, exp.Is) and isinstance(inner.expression, exp.Null):
condition_expr = inner.this

if isinstance(condition_expr, exp.ObjectId):
object_type = condition_expr.expression
if (
(object_type is None or object_type.name.upper() == "U")
and isinstance(true_block, exp.Block)
and isinstance(drop := true_block.expressions[0], exp.Drop)
):
drop.set("exists", True)
return self.sql(drop)

return super().ifblock_sql(expression)
Generator = SparkGenerator
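
Since the diff only shows the deletions and the re-pointed class attributes, a quick hypothetical smoke test for the rewiring (assumes the refactor is meant to keep the public dialect surface and generation behavior unchanged):

import sqlglot
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.spark import Spark
from sqlglot.generators.databricks import DatabricksGenerator
from sqlglot.generators.spark import SparkGenerator

# The dialects should now simply re-export the extracted generator classes
assert Databricks.Generator is DatabricksGenerator
assert Spark.Generator is SparkGenerator

# Generation behavior should be unaffected by the move, e.g. Databricks
# still renders JSON extraction with the colon operator
sql = sqlglot.transpile("SELECT col:item.price", read="databricks", write="databricks")[0]
print(sql)  # expected: SELECT col:item.price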