Skip to content

Commit b6bf566

Browse files
committed
Add additional unit tests for parameterized queries
1 parent f10d958 commit b6bf566

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

python/datafusion/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def value_to_string(value) -> str:
638638
return str(value)
639639

640640
param_values = (
641-
{name: value_to_scalar(value) for (name, value) in param_values}
641+
{name: value_to_scalar(value) for (name, value) in param_values.items()}
642642
if param_values is not None
643643
else {}
644644
)

python/tests/test_sql.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -535,18 +535,20 @@ def test_register_listing_table(
535535
assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
536536

537537

538-
def test_parameterized_df_in_sql(ctx, tmp_path) -> None:
538+
def test_parameterized_named_params(ctx, tmp_path) -> None:
539539
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
540540

541541
df = ctx.read_parquet(path)
542542
result = ctx.sql(
543-
"SELECT COUNT(a) AS cnt FROM $replaced_df", replaced_df=df
543+
"SELECT COUNT(a) AS cnt, $lit_val as lit_val FROM $replaced_df",
544+
lit_val=3,
545+
replaced_df=df,
544546
).collect()
545547
result = pa.Table.from_batches(result)
546-
assert result.to_pydict() == {"cnt": [100]}
548+
assert result.to_pydict() == {"cnt": [100], "lit_val": [3]}
547549

548550

549-
def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
551+
def test_parameterized_param_values(ctx: SessionContext) -> None:
550552
# Test the parameters that should be handled by the parser rather
551553
# than our manipulation of the query string by searching for tokens
552554
batch = pa.RecordBatch.from_arrays(
@@ -555,5 +557,22 @@ def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
555557
)
556558

557559
ctx.register_record_batches("t", [[batch]])
558-
result = ctx.sql("SELECT a FROM t WHERE a < $val", val=3)
560+
result = ctx.sql("SELECT a FROM t WHERE a < $val", param_values={"val": 3})
561+
assert result.to_pydict() == {"a": [1, 2]}
562+
563+
564+
def test_parameterized_mixed_query(ctx: SessionContext) -> None:
565+
batch = pa.RecordBatch.from_arrays(
566+
[pa.array([1, 2, 3, 4])],
567+
names=["a"],
568+
)
569+
ctx.register_record_batches("t", [[batch]])
570+
registered_df = ctx.table("t")
571+
572+
result = ctx.sql(
573+
"SELECT $col_name FROM $df WHERE a < $val",
574+
param_values={"val": 3},
575+
df=registered_df,
576+
col_name="a",
577+
)
559578
assert result.to_pydict() == {"a": [1, 2]}

0 commit comments

Comments
 (0)