Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,7 +1739,6 @@ class Generator(generator.Generator):
exp.RegrValy: _regr_val_sql,
exp.Return: lambda self, e: self.sql(e, "this"),
exp.ReturnsProperty: lambda self, e: "TABLE" if isinstance(e.this, exp.Schema) else "",
exp.StrPosition: strposition_sql,
exp.StrToUnix: lambda self, e: self.func(
"EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
),
Expand Down Expand Up @@ -2563,6 +2562,26 @@ def approxtopk_sql(self, expression: exp.ApproxTopK) -> str:
def fromiso8601timestamp_sql(self, expression: exp.FromISO8601Timestamp) -> str:
return self.sql(exp.cast(expression.this, exp.DType.TIMESTAMPTZ))

def strposition_sql(self, expression: exp.StrPosition) -> str:
position = expression.args.get("position")
if (
expression.args.get("clamp_position")
and position
and not (position.is_number and position.to_py() > 0)
):
Comment on lines +2566 to +2571
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should apply the logic on the transpiled query, right ?

WITH t as (select -1 as p) SELECT CHARINDEX('l', 'hello world', p) from t

We don't cover this ^ case. On parse time we don't know the value of p.

expression = expression.copy()
expression.set(
"position",
exp.Literal.number(1)
if position.is_number
else exp.If(
this=exp.LTE(this=position, expression=exp.Literal.number(0)),
true=exp.Literal.number(1),
false=position.copy(),
),
)
return strposition_sql(self, expression)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just always generate the check=> If the postion is present apply the check.

So we can avoid this checkand not (position.is_number and position.to_py() > 0)
and this if position.is_number.

I think it doesn't break a lot of tranpsilation tests.


def strtotime_sql(self, expression: exp.StrToTime) -> str:
# Check if target_type requires TIMESTAMPTZ (for LTZ/TZ variants)
target_type = expression.args.get("target_type")
Expand Down
1 change: 1 addition & 0 deletions sqlglot/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class StrPosition(Expression, Func):
"substr": True,
"position": False,
"occurrence": False,
"clamp_position": False,
}


Expand Down
6 changes: 6 additions & 0 deletions sqlglot/parsers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,12 @@ class SnowflakeParser(parser.Parser):

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
clamp_position=True,
),
"ADD_MONTHS": lambda args: exp.AddMonths(
this=seq_get(args, 0),
expression=seq_get(args, 1),
Expand Down
16 changes: 16 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -6503,6 +6503,22 @@ def test_space(self):
},
)

def test_charindex(self):
self.validate_all(
"SELECT CHARINDEX('sub', 'testsubstring', -1)",
write={
"snowflake": "SELECT CHARINDEX('sub', 'testsubstring', -1)",
"duckdb": "SELECT CASE WHEN STRPOS(SUBSTRING('testsubstring', 1), 'sub') = 0 THEN 0 ELSE STRPOS(SUBSTRING('testsubstring', 1), 'sub') + 1 - 1 END",
},
)
self.validate_all(
"SELECT CHARINDEX('sub', 'testsubstring', p)",
write={
"snowflake": "SELECT CHARINDEX('sub', 'testsubstring', p)",
"duckdb": "SELECT CASE WHEN STRPOS(SUBSTRING('testsubstring', CASE WHEN p <= 0 THEN 1 ELSE p END), 'sub') = 0 THEN 0 ELSE STRPOS(SUBSTRING('testsubstring', CASE WHEN p <= 0 THEN 1 ELSE p END), 'sub') + CASE WHEN p <= 0 THEN 1 ELSE p END - 1 END",
},
)

def test_directed_joins(self):
self.validate_identity("SELECT * FROM a CROSS DIRECTED JOIN b USING (id)")
self.validate_identity("SELECT * FROM a INNER DIRECTED JOIN b USING (id)")
Expand Down
Loading