|
19 | 19 |
|
20 | 20 | import pyarrow as pa
|
21 | 21 | import pytest
|
22 |
| -from datafusion import SessionContext, col, functions, lit |
| 22 | +from datafusion import ( |
| 23 | + SessionContext, |
| 24 | + col, |
| 25 | + functions, |
| 26 | + lit, |
| 27 | + lit_with_metadata, |
| 28 | + literal_with_metadata, |
| 29 | +) |
23 | 30 | from datafusion.expr import (
|
24 | 31 | Aggregate,
|
25 | 32 | AggregateFunction,
|
@@ -103,7 +110,7 @@ def test_limit(test_ctx):
|
103 | 110 |
|
104 | 111 | plan = plan.to_variant()
|
105 | 112 | assert isinstance(plan, Limit)
|
106 |
| - assert "Skip: Some(Literal(Int64(5)))" in str(plan) |
| 113 | + assert "Skip: Some(Literal(Int64(5), None))" in str(plan) |
107 | 114 |
|
108 | 115 |
|
109 | 116 | def test_aggregate_query(test_ctx):
|
@@ -824,3 +831,52 @@ def test_expr_functions(ctx, function, expected_result):
|
824 | 831 |
|
825 | 832 | assert len(result) == 1
|
826 | 833 | assert result[0].column(0).equals(expected_result)
|
| 834 | + |
| 835 | + |
| 836 | +def test_literal_metadata(ctx): |
| 837 | + result = ( |
| 838 | + ctx.from_pydict({"a": [1]}) |
| 839 | + .select( |
| 840 | + lit(1).alias("no_metadata"), |
| 841 | + lit_with_metadata(2, {"key1": "value1"}).alias("lit_with_metadata_fn"), |
| 842 | + literal_with_metadata(3, {"key2": "value2"}).alias( |
| 843 | + "literal_with_metadata_fn" |
| 844 | + ), |
| 845 | + ) |
| 846 | + .collect() |
| 847 | + ) |
| 848 | + |
| 849 | + expected_schema = pa.schema( |
| 850 | + [ |
| 851 | + pa.field("no_metadata", pa.int64(), nullable=False), |
| 852 | + pa.field( |
| 853 | + "lit_with_metadata_fn", |
| 854 | + pa.int64(), |
| 855 | + nullable=False, |
| 856 | + metadata={"key1": "value1"}, |
| 857 | + ), |
| 858 | + pa.field( |
| 859 | + "literal_with_metadata_fn", |
| 860 | + pa.int64(), |
| 861 | + nullable=False, |
| 862 | + metadata={"key2": "value2"}, |
| 863 | + ), |
| 864 | + ] |
| 865 | + ) |
| 866 | + |
| 867 | + expected = pa.RecordBatch.from_pydict( |
| 868 | + { |
| 869 | + "no_metadata": pa.array([1]), |
| 870 | + "lit_with_metadata_fn": pa.array([2]), |
| 871 | + "literal_with_metadata_fn": pa.array([3]), |
| 872 | + }, |
| 873 | + schema=expected_schema, |
| 874 | + ) |
| 875 | + |
| 876 | + assert result[0] == expected |
| 877 | + |
| 878 | + # Testing result[0].schema == expected_schema does not check each key/value pair |
| 879 | + # so we want to explicitly test these |
| 880 | + for expected_field in expected_schema: |
| 881 | + actual_field = result[0].schema.field(expected_field.name) |
| 882 | + assert expected_field.metadata == actual_field.metadata |
0 commit comments