Skip to content

Commit 5e35889

Browse files
authored
Merge pull request #43 from awslabs/normalize-partitions-names
Normalizing partitions and casting columns names
2 parents 1dafb46 + 6d59d8e commit 5e35889

File tree

2 files changed

+74
-8
lines changed

2 files changed

+74
-8
lines changed

awswrangler/pandas.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
EmptyDataframe, InvalidSerDe,
1818
InvalidCompression)
1919
from awswrangler.utils import calculate_bounders
20-
from awswrangler import s3, athena
20+
from awswrangler import s3
21+
from awswrangler.athena import Athena
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -607,8 +608,21 @@ def to_s3(self,
607608
:param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
608609
:return: List of objects written on S3
609610
"""
611+
if not partition_cols:
612+
partition_cols = []
613+
if not cast_columns:
614+
cast_columns = {}
610615
dataframe = Pandas.normalize_columns_names_athena(dataframe,
611616
inplace=inplace)
617+
cast_columns = {
618+
Athena.normalize_column_name(k): v
619+
for k, v in cast_columns.items()
620+
}
621+
logger.debug(f"cast_columns: {cast_columns}")
622+
partition_cols = [
623+
Athena.normalize_column_name(x) for x in partition_cols
624+
]
625+
logger.debug(f"partition_cols: {partition_cols}")
612626
dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe,
613627
inplace=inplace)
614628
if compression is not None:
@@ -628,8 +642,6 @@ def to_s3(self,
628642
raise UnsupportedFileFormat(file_format)
629643
if dataframe.empty:
630644
raise EmptyDataframe()
631-
if not partition_cols:
632-
partition_cols = []
633645
if ((mode == "overwrite")
634646
or ((mode == "overwrite_partitions") and # noqa
635647
(not partition_cols))):
@@ -1042,7 +1054,7 @@ def normalize_columns_names_athena(dataframe, inplace=True):
10421054
if inplace is False:
10431055
dataframe = dataframe.copy(deep=True)
10441056
dataframe.columns = [
1045-
athena.Athena.normalize_column_name(x) for x in dataframe.columns
1057+
Athena.normalize_column_name(x) for x in dataframe.columns
10461058
]
10471059
return dataframe
10481060

testing/test_awswrangler/test_pandas.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -781,16 +781,22 @@ def test_read_sql_athena_with_time_zone(session, bucket, database):
781781

782782
def test_normalize_columns_names_athena():
783783
dataframe = pandas.DataFrame({
784-
"CammelCase": [1, 2, 3],
784+
"CamelCase": [1, 2, 3],
785785
"With Spaces": [4, 5, 6],
786786
"With-Dash": [7, 8, 9],
787787
"Ãccént": [10, 11, 12],
788+
"with.dot": [10, 11, 12],
789+
"Camel_Case2": [13, 14, 15],
790+
"Camel___Case3": [16, 17, 18]
788791
})
789792
Pandas.normalize_columns_names_athena(dataframe=dataframe, inplace=True)
790-
assert dataframe.columns[0] == "cammel_case"
793+
assert dataframe.columns[0] == "camel_case"
791794
assert dataframe.columns[1] == "with_spaces"
792795
assert dataframe.columns[2] == "with_dash"
793796
assert dataframe.columns[3] == "accent"
797+
assert dataframe.columns[4] == "with_dot"
798+
assert dataframe.columns[5] == "camel_case2"
799+
assert dataframe.columns[6] == "camel_case3"
794800

795801

796802
def test_to_parquet_with_normalize(
@@ -799,11 +805,13 @@ def test_to_parquet_with_normalize(
799805
database,
800806
):
801807
dataframe = pandas.DataFrame({
802-
"CammelCase": [1, 2, 3],
808+
"CamelCase": [1, 2, 3],
803809
"With Spaces": [4, 5, 6],
804810
"With-Dash": [7, 8, 9],
805811
"Ãccént": [10, 11, 12],
806812
"with.dot": [10, 11, 12],
813+
"Camel_Case2": [13, 14, 15],
814+
"Camel___Case3": [16, 17, 18]
807815
})
808816
session.pandas.to_parquet(dataframe=dataframe,
809817
database=database,
@@ -818,11 +826,57 @@ def test_to_parquet_with_normalize(
818826
sleep(2)
819827
assert len(dataframe.index) == len(dataframe2.index)
820828
assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
821-
assert dataframe2.columns[0] == "cammel_case"
829+
assert dataframe2.columns[0] == "camel_case"
822830
assert dataframe2.columns[1] == "with_spaces"
823831
assert dataframe2.columns[2] == "with_dash"
824832
assert dataframe2.columns[3] == "accent"
825833
assert dataframe2.columns[4] == "with_dot"
834+
assert dataframe2.columns[5] == "camel_case2"
835+
assert dataframe2.columns[6] == "camel_case3"
836+
837+
838+
def test_to_parquet_with_normalize_and_cast(
839+
session,
840+
bucket,
841+
database,
842+
):
843+
dataframe = pandas.DataFrame({
844+
"CamelCase": [1, 2, 3],
845+
"With Spaces": [4, 5, 6],
846+
"With-Dash": [7, 8, 9],
847+
"Ãccént": [10, 11, 12],
848+
"with.dot": [10, 11, 12],
849+
"Camel_Case2": [13, 14, 15],
850+
"Camel___Case3": [16, 17, 18]
851+
})
852+
session.pandas.to_parquet(dataframe=dataframe,
853+
database=database,
854+
path=f"s3://{bucket}/TestTable-with.dot/",
855+
mode="overwrite",
856+
partition_cols=["CamelCase"],
857+
cast_columns={
858+
"Camel_Case2": "double",
859+
"Camel___Case3": "float"
860+
})
861+
dataframe2 = None
862+
for counter in range(10):
863+
dataframe2 = session.pandas.read_sql_athena(
864+
sql="select * from test_table_with_dot", database=database)
865+
if len(dataframe.index) == len(dataframe2.index):
866+
break
867+
sleep(2)
868+
assert len(dataframe.index) == len(dataframe2.index)
869+
assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
870+
assert dataframe2.columns[0] == "with_spaces"
871+
assert dataframe2.columns[1] == "with_dash"
872+
assert dataframe2.columns[2] == "accent"
873+
assert dataframe2.columns[3] == "with_dot"
874+
assert dataframe2.columns[4] == "camel_case2"
875+
assert dataframe2.columns[5] == "camel_case3"
876+
assert dataframe2.columns[6] == "__index_level_0__"
877+
assert dataframe2.columns[7] == "camel_case"
878+
assert dataframe2[dataframe2.columns[4]].dtype == "float64"
879+
assert dataframe2[dataframe2.columns[5]].dtype == "float64"
826880

827881

828882
def test_drop_duplicated_columns():

0 commit comments

Comments
 (0)