Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,27 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
)
return sg.select(*columns_to_keep).from_(parent)

def visit_TemplateSQL(
self,
op: ops.TemplateSQL,
*,
strings: tuple[str],
values: tuple[sge.Expression],
dialect: str,
):
def iter():
for s, i in itertools.zip_longest(strings, values):
if s:
yield s
if i:
yield i

str_parts = [
part if isinstance(part, str) else part.sql(dialect) for part in iter()
]
sql = "".join(str_parts)
return sg.parse_one(sql, read=dialect)

def add_query_to_expr(self, *, name: str, table: ir.Table, query: str) -> str:
dialect = self.dialect

Expand Down
82 changes: 82 additions & 0 deletions ibis/backends/tests/test_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import datetime
import zoneinfo

import pytest

import ibis
from ibis.tests.tstring import t

tm = pytest.importorskip("pandas.testing")

five = ibis.literal(5)
world = ibis.literal("world")


@pytest.mark.notimpl(["polars"])
@pytest.mark.parametrize(
("template", "expected_result"),
[
(t("{five} + 3"), 8),
(t("{five:.2f} + 3"), 8), # format strings ignored
(t("'hello ' || {world}"), "hello world"),
(t("'hello ' || {world!r}"), "hello world"), # conversion strings ignored
],
)
def test_scalar(con, template, expected_result):
"""Test that scalar template expressions execute correctly."""
expr = ibis.sql_value(template)
result = con.execute(expr)
assert result == expected_result


@pytest.mark.xfail(
reason="sqlglot hasn't implemented inferring the dtype from this complex expression"
)
def test_complex_timestamp():
# parse a UTC timestamp into alaska local time, eg "8/1/2024 21:44:00" into 2024-08-01 13:44:00 (8 hours before UTC).
con = ibis.duckdb.connect()
timestamp = ibis.timestamp("2024-08-01 21:44:00") # noqa: F841
in_ak_time = ibis.sql_value(t("{timestamp} AT TIME ZONE 'America/Anchorage'"))
result = con.execute(in_ak_time)
expected = datetime.datetime(
2024, 8, 1, 13, 44, 0, tzinfo=zoneinfo.ZoneInfo("America/Anchorage")
)
assert result == expected


@pytest.mark.notimpl(["polars"])
def test_column(con, alltypes):
"""Test template with column interpolation."""
c = alltypes.int_col # noqa: F841
template = t("{c + 2} - 1")
expr = ibis.sql_value(template)
result = con.execute(expr)
expected = con.execute(alltypes.int_col + 1)
tm.assert_series_equal(result, expected, check_names=False)


def test_dialect():
pa = pytest.importorskip("pyarrow")
five = ibis.literal(5) # noqa: F841
template = t("CAST({five} AS REAL)")

expr_sqlite = ibis.sql_value(template, dialect="sqlite")
expr_default = ibis.sql_value(template)

con_sqlite = ibis.sqlite.connect()
result = con_sqlite.to_pyarrow(expr_default)
assert result.type == pa.float32()
assert result.as_py() == 5.0
result = con_sqlite.to_pyarrow(expr_sqlite)
assert result.type == pa.float64()
assert result.as_py() == 5.0

con_duckdb = ibis.duckdb.connect()
result = con_duckdb.to_pyarrow(expr_default)
assert result.type == pa.float32()
assert result.as_py() == 5.0
result = con_duckdb.to_pyarrow(expr_sqlite)
assert result.type == pa.float64()
assert result.as_py() == 5.0
83 changes: 83 additions & 0 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ibis.common.temporal import normalize_datetime, normalize_timezone
from ibis.expr.datatypes import DataType
from ibis.expr.decompile import decompile
from ibis.expr.operations.template import IntoInterpolation, IntoTemplate
from ibis.expr.schema import Schema
from ibis.expr.sql import parse_sql, to_sql
from ibis.expr.types import (
Expand Down Expand Up @@ -62,6 +63,8 @@
"DataType",
"Deferred",
"Expr",
"IntoInterpolation",
"IntoTemplate",
"Scalar",
"Schema",
"Table",
Expand Down Expand Up @@ -120,6 +123,7 @@
"schema",
"selectors",
"set_backend",
"sql_value",
"struct",
"table",
"time",
Expand Down Expand Up @@ -594,6 +598,85 @@ def _deferred_method_call(expr, method_name, **kwargs):
return method(value)


def sql_value(template: IntoTemplate, /, *, dialect: str | None = None) -> ir.Value:
"""Create an ibis value from a t-string.

t-strings, or Template Strings, were added as builtin syntax in Python 3.14.
See https://docs.python.org/3.14/library/string.templatelib.html
for more information.

This function allows you to create an ibis value expression from a t-string.
It does NOT support generic SELECT statements, only expressions that
represent a single value.

Parameters
----------
template
The template to use for creating the SQL expression.
dialect
The SQL dialect to use for the expression.
Defaults to "duckdb".

Returns
-------
ValueExpr
An ibis ValueExpr.

Examples
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> con = ibis.duckdb.connect()
>>> table = con.create_table("my_table", {"a": [1, 2, 3], "b": [4, 5, 6]})

If you are using python 3.14+, you can replace the lines
below with `template = t"{table.b} + 3 - {table.a / 10}"`.
Here, since we are testing on older versions,
we use a tiny implementation of t-strings included in ibis that works as a replacement.
If you are on python < 3.14, you should use a backport such as
https://pypi.org/project/tstrings-backport and do `from tstrings import t`.

>>> from ibis.tests.tstring import t
>>> template = t("{table.b} + 3 - {table.a / 10}")

Now create an ibis expression based on this.

>>> expr = ibis.sql_value(template)
>>> print(expr.to_sql())
SELECT
"t0"."b" + 3 - "t0"."a" / 10 AS "TemplateSQL((), (b, Divide(a, 10)))"
FROM "memory"."main"."my_table" AS "t0"
>>> table.mutate(expr=expr, s=expr.cast(str) + "!")
┏━━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ a ┃ b ┃ expr ┃ s ┃
┡━━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ int64 │ int64 │ float64 │ string │
├───────┼───────┼─────────┼────────┤
│ 1 │ 4 │ 6.9 │ 6.9! │
│ 2 │ 5 │ 7.8 │ 7.8! │
│ 3 │ 6 │ 8.7 │ 8.7! │
└───────┴───────┴─────────┴────────┘

You can provide a `dialect` parameter if you pass in a template written in
a specific SQL dialect, and then this will be transpiled to
the correct dialect upon execution.

For example, write a template in sqlite syntax (with datatype REAL)
and then execute it on duckdb (where REAL will be interpreted as DOUBLE).

>>> template = t("CAST({table.a} AS REAL)")
>>> expr = ibis.sql_value(template, dialect="sqlite")
>>> arr = con.to_pyarrow(expr)
>>> arr.type
DataType(double)
>>> arr.to_pylist()
[1.0, 2.0, 3.0]
"""
from ibis.expr.operations.template import TemplateSQL

return TemplateSQL.from_template(template, dialect=dialect).to_expr()


def desc(expr: ir.Column | str, /, *, nulls_first: bool = False) -> ir.Value:
"""Create a descending sort key from `expr` or column name.

Expand Down
1 change: 1 addition & 0 deletions ibis/expr/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ibis.expr.operations.strings import * # noqa: F403
from ibis.expr.operations.structs import * # noqa: F403
from ibis.expr.operations.subqueries import * # noqa: F403
from ibis.expr.operations.template import TemplateSQL # noqa: F401
from ibis.expr.operations.temporal import * # noqa: F403
from ibis.expr.operations.temporal_windows import * # noqa: F403
from ibis.expr.operations.udf import * # noqa: F403
Expand Down
127 changes: 127 additions & 0 deletions ibis/expr/operations/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Operations for template strings (t-strings)."""

from __future__ import annotations

from itertools import zip_longest
from typing import TYPE_CHECKING, Optional, Protocol

import sqlglot as sg
import sqlglot.expressions as sge
from public import public
from sqlglot.optimizer.annotate_types import annotate_types
from typing_extensions import runtime_checkable

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.typing import VarTuple # noqa: TC001
from ibis.expr.operations.core import Value

if TYPE_CHECKING:
from collections.abc import Iterator

from ibis.backends.sql.datatypes import SqlglotType
from ibis.expr.operations.relations import Relation
from ibis.expr.types.generic import Value as ExprValue


@runtime_checkable
class IntoInterpolation(Protocol):
"""Protocol for an object that can be interpreted as a PEP 750 t-string Interpolation."""

value: ExprValue
expression: str


@runtime_checkable
class IntoTemplate(Protocol):
"""Protocol for an object that can be interpreted as a PEP 750 t-string Template."""

strings: tuple[str, ...]
interpolations: tuple[IntoInterpolation, ...]


@public
class TemplateSQL(Value):
strings: VarTuple[str]
values: VarTuple[Value]
dialect: Optional[str] = None
"""The SQL dialect the template was written in.

eg if t'CAST({val} AS REAL)', you should use 'sqlite',
since REAL is a sqlite-specific concept.
"""

def __init__(self, strings, values, dialect: str | None = None):
super().__init__(strings=strings, values=values, dialect=dialect or "duckdb")
if self.dtype.is_unknown():
raise TypeError(
f"Could not infer the dtype of the template expression with sql:\n{self.sql_for_inference}"
)

@classmethod
def from_template(
cls, template: IntoTemplate, /, *, dialect: str | None = None
) -> TemplateSQL:
return cls(
strings=template.strings,
values=[interp.value for interp in template.interpolations],
dialect=dialect,
)

@attribute
def shape(self):
if not self.values:
return ds.scalar
return rlz.highest_precedence_shape(self.values)

@attribute
def dtype(self) -> dt.DataType:
parsed = sg.parse_one(self.sql_for_inference, dialect=self.dialect)
annotated = annotate_types(parsed, dialect=self.dialect)
sqlglot_type = annotated.type
return self.type_mapper.to_ibis(sqlglot_type)

@attribute
def relations(self) -> frozenset[Relation]:
children = (n.relations for n in self.values)
return frozenset().union(*children)

@property
def sql_for_inference(self) -> str:
parts: list[str] = []
for part in self.parts:
if isinstance(part, str):
parts.append(part)
else:
ibis_type: dt.DataType = part.dtype
null_sqlglot_value = sge.cast(
sge.null(), self.type_mapper.from_ibis(ibis_type)
)
parts.append(null_sqlglot_value.sql(self.dialect))
return "".join(parts)

@property
def type_mapper(self) -> SqlglotType:
return get_type_mapper(self.dialect)

@property
def parts(self):
def iter() -> Iterator[str | Value]:
for s, i in zip_longest(self.strings, self.values):
if s:
yield s
if i:
yield i

return tuple(iter())


def get_type_mapper(dialect: str | None) -> SqlglotType:
"""Get the type mapper for the given SQL dialect."""
import importlib

module = importlib.import_module(f"ibis.backends.sql.compilers.{dialect}")
compiler_instance = module.compiler
return compiler_instance.type_mapper
Loading
Loading