Skip to content

Commit 93d8802

Browse files
authored
Merge branch 'main' into tswast-bbq-obj
2 parents e89ccfb + e1d54d2 commit 93d8802

File tree

20 files changed

+231
-104
lines changed

20 files changed

+231
-104
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bigframes.core import window_spec
2424
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
2525
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
26+
from bigframes.core.compile.sqlglot.expressions import constants
2627
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2728
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2829
from bigframes.operations import aggregations as agg_ops
@@ -44,9 +45,13 @@ def _(
4445
column: typed_expr.TypedExpr,
4546
window: typing.Optional[window_spec.WindowSpec] = None,
4647
) -> sge.Expression:
47-
# BQ will return null for empty column, result would be false in pandas.
48-
result = apply_window_if_present(sge.func("LOGICAL_AND", column.expr), window)
49-
return sge.func("IFNULL", result, sge.true())
48+
expr = column.expr
49+
if column.dtype != dtypes.BOOL_DTYPE:
50+
expr = sge.NEQ(this=expr, expression=sge.convert(0))
51+
expr = apply_window_if_present(sge.func("LOGICAL_AND", expr), window)
52+
53+
# BQ will return null for empty column, result would be true in pandas.
54+
return sge.func("COALESCE", expr, sge.convert(True))
5055

5156

5257
@UNARY_OP_REGISTRATION.register(agg_ops.AnyOp)
@@ -56,6 +61,8 @@ def _(
5661
window: typing.Optional[window_spec.WindowSpec] = None,
5762
) -> sge.Expression:
5863
expr = column.expr
64+
if column.dtype != dtypes.BOOL_DTYPE:
65+
expr = sge.NEQ(this=expr, expression=sge.convert(0))
5966
expr = apply_window_if_present(sge.func("LOGICAL_OR", expr), window)
6067

6168
# BQ will return null for empty column, result would be false in pandas.
@@ -326,6 +333,15 @@ def _(
326333
unit=sge.Identifier(this="MICROSECOND"),
327334
)
328335

336+
if column.dtype == dtypes.DATE_DTYPE:
337+
date_diff = sge.DateDiff(
338+
this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY")
339+
)
340+
return sge.Cast(
341+
this=sge.Floor(this=date_diff * constants._DAY_TO_MICROSECONDS),
342+
to="INT64",
343+
)
344+
329345
raise TypeError(f"Cannot perform diff on type {column.dtype}")
330346

331347

@@ -410,24 +426,28 @@ def _(
410426
column: typed_expr.TypedExpr,
411427
window: typing.Optional[window_spec.WindowSpec] = None,
412428
) -> sge.Expression:
429+
expr = column.expr
430+
if column.dtype == dtypes.BOOL_DTYPE:
431+
expr = sge.Cast(this=expr, to="INT64")
432+
413433
# Need to short-circuit as log with zeroes is illegal sql
414-
is_zero = sge.EQ(this=column.expr, expression=sge.convert(0))
434+
is_zero = sge.EQ(this=expr, expression=sge.convert(0))
415435

416436
# There is no product sql aggregate function, so must implement as a sum of logs, and then
417437
# apply power after. Note, log and power base must be equal! This impl uses natural log.
418-
logs = (
419-
sge.Case()
420-
.when(is_zero, sge.convert(0))
421-
.else_(sge.func("LN", sge.func("ABS", column.expr)))
438+
logs = sge.If(
439+
this=is_zero,
440+
true=sge.convert(0),
441+
false=sge.func("LOG", sge.convert(2), sge.func("ABS", expr)),
422442
)
423443
logs_sum = apply_window_if_present(sge.func("SUM", logs), window)
424-
magnitude = sge.func("EXP", logs_sum)
444+
magnitude = sge.func("POWER", sge.convert(2), logs_sum)
425445

426446
# Can't determine sign from logs, so have to determine parity of count of negative inputs
427447
is_negative = (
428448
sge.Case()
429449
.when(
430-
sge.LT(this=sge.func("SIGN", column.expr), expression=sge.convert(0)),
450+
sge.EQ(this=sge.func("SIGN", expr), expression=sge.convert(-1)),
431451
sge.convert(1),
432452
)
433453
.else_(sge.convert(0))
@@ -445,11 +465,7 @@ def _(
445465
.else_(
446466
sge.Mul(
447467
this=magnitude,
448-
expression=sge.If(
449-
this=sge.EQ(this=negative_count_parity, expression=sge.convert(1)),
450-
true=sge.convert(-1),
451-
false=sge.convert(1),
452-
),
468+
expression=sge.func("POWER", sge.convert(-1), negative_count_parity),
453469
)
454470
)
455471
)
@@ -499,14 +515,18 @@ def _(
499515
column: typed_expr.TypedExpr,
500516
window: typing.Optional[window_spec.WindowSpec] = None,
501517
) -> sge.Expression:
502-
# TODO: Support interpolation argument
503-
# TODO: Support percentile_disc
504-
result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
518+
expr = column.expr
519+
if column.dtype == dtypes.BOOL_DTYPE:
520+
expr = sge.Cast(this=expr, to="INT64")
521+
522+
result: sge.Expression = sge.func("PERCENTILE_CONT", expr, sge.convert(op.q))
505523
if window is None:
506-
# PERCENTILE_CONT is a navigation function, not an aggregate function, so it always needs an OVER clause.
524+
# PERCENTILE_CONT is a navigation function, not an aggregate function,
525+
# so it always needs an OVER clause.
507526
result = sge.Window(this=result)
508527
else:
509528
result = apply_window_if_present(result, window)
529+
510530
if op.should_floor_result:
511531
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
512532
return result

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
import typing
1818

19+
import bigframes_vendored.sqlglot as sg
1920
import bigframes_vendored.sqlglot.expressions as sge
2021
import pandas as pd
2122

2223
from bigframes import dtypes
2324
from bigframes import operations as ops
25+
from bigframes.core.compile.sqlglot import sqlglot_ir
2426
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2527
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2628

@@ -62,6 +64,10 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:
6264

6365
@register_binary_op(ops.eq_op)
6466
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
67+
if sqlglot_ir._is_null_literal(left.expr):
68+
return sge.Is(this=right.expr, expression=sge.Null())
69+
if sqlglot_ir._is_null_literal(right.expr):
70+
return sge.Is(this=left.expr, expression=sge.Null())
6571
left_expr = _coerce_bool_to_int(left)
6672
right_expr = _coerce_bool_to_int(right)
6773
return sge.EQ(this=left_expr, expression=right_expr)
@@ -139,6 +145,17 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
139145

140146
@register_binary_op(ops.ne_op)
141147
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
148+
if sqlglot_ir._is_null_literal(left.expr):
149+
return sge.Is(
150+
this=sge.paren(right.expr, copy=False),
151+
expression=sg.not_(sge.Null(), copy=False),
152+
)
153+
if sqlglot_ir._is_null_literal(right.expr):
154+
return sge.Is(
155+
this=sge.paren(left.expr, copy=False),
156+
expression=sg.not_(sge.Null(), copy=False),
157+
)
158+
142159
left_expr = _coerce_bool_to_int(left)
143160
right_expr = _coerce_bool_to_int(right)
144161
return sge.NEQ(this=left_expr, expression=right_expr)

bigframes/core/compile/sqlglot/expressions/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
2121
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")
2222
_NEG_INF = sge.Cast(this=sge.convert("-Infinity"), to="FLOAT64")
23+
_DAY_TO_MICROSECONDS = sge.convert(86400000000)
2324

2425
# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result
2526
# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10)

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from bigframes import dtypes
2121
from bigframes import operations as ops
22-
from bigframes.core.compile.sqlglot import sqlglot_types
22+
from bigframes.core.compile.sqlglot import sqlglot_ir, sqlglot_types
2323
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2424
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2525

@@ -101,11 +101,23 @@ def _(expr: TypedExpr) -> sge.Expression:
101101
def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
102102
if len(op.mappings) == 0:
103103
return expr.expr
104+
105+
mappings = [
106+
(
107+
sqlglot_ir._literal(key, dtypes.is_compatible(key, expr.dtype)),
108+
sqlglot_ir._literal(value, dtypes.is_compatible(value, expr.dtype)),
109+
)
110+
for key, value in op.mappings
111+
]
104112
return sge.Case(
105-
this=expr.expr,
106113
ifs=[
107-
sge.If(this=sge.convert(key), true=sge.convert(value))
108-
for key, value in op.mappings
114+
sge.If(
115+
this=sge.EQ(this=expr.expr, expression=key)
116+
if not sqlglot_ir._is_null_literal(key)
117+
else sge.Is(this=expr.expr, expression=sge.Null()),
118+
true=value,
119+
)
120+
for key, value in mappings
109121
],
110122
default=expr.expr,
111123
)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,15 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
642642
return new_select_expr
643643

644644

645+
def _is_null_literal(expr: sge.Expression) -> bool:
646+
"""Checks if the given expression is a NULL literal."""
647+
if isinstance(expr, sge.Null):
648+
return True
649+
if isinstance(expr, sge.Cast) and isinstance(expr.this, sge.Null):
650+
return True
651+
return False
652+
653+
645654
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
646655
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
647656
if sqlglot_type is None:
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`bool_col`
3+
`bool_col`,
4+
`int64_col`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_1`
8+
COALESCE(LOGICAL_AND(`bool_col`), TRUE) AS `bfcol_2`,
9+
COALESCE(LOGICAL_AND(`int64_col` <> 0), TRUE) AS `bfcol_3`
810
FROM `bfcte_0`
911
)
1012
SELECT
11-
`bfcol_1` AS `bool_col`
13+
`bfcol_2` AS `bool_col`,
14+
`bfcol_3` AS `int64_col`
1215
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_partition_out.sql

Lines changed: 0 additions & 14 deletions
This file was deleted.

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all/window_out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_all_w_window/out.sql

File renamed without changes.
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`bool_col`
3+
`bool_col`,
4+
`int64_col`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
7-
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_1`
8+
COALESCE(LOGICAL_OR(`bool_col`), FALSE) AS `bfcol_2`,
9+
COALESCE(LOGICAL_OR(`int64_col` <> 0), FALSE) AS `bfcol_3`
810
FROM `bfcte_0`
911
)
1012
SELECT
11-
`bfcol_1` AS `bool_col`
13+
`bfcol_2` AS `bool_col`,
14+
`bfcol_3` AS `int64_col`
1215
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any/window_out.sql renamed to tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_any_w_window/out.sql

File renamed without changes.

0 commit comments

Comments
 (0)