Skip to content

Commit 85b00d5

Browse files
authored
Merge pull request #228 from awslabs/dev
Bumping version to 1.1.2
2 parents c693b26 + 84faf63 commit 85b00d5

File tree

13 files changed

+412
-391
lines changed

13 files changed

+412
-391
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Due the new major version `1.0.0` with breaking changes, please make sure that a
1111

1212
![AWS Data Wrangler](docs/source/_static/logo2.png?raw=true "AWS Data Wrangler")
1313

14-
[![Release](https://img.shields.io/badge/release-1.1.1-brightgreen.svg)](https://pypi.org/project/awswrangler/)
14+
[![Release](https://img.shields.io/badge/release-1.1.2-brightgreen.svg)](https://pypi.org/project/awswrangler/)
1515
[![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-brightgreen.svg)](https://anaconda.org/conda-forge/awswrangler)
1616
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
1717
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

awswrangler/__metadata__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77

88
__title__ = "awswrangler"
99
__description__ = "Pandas on AWS."
10-
__version__ = "1.1.1"
10+
__version__ = "1.1.2"
1111
__license__ = "Apache License 2.0"

awswrangler/_data_types.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,14 @@ def pyarrow2athena(dtype: pa.DataType) -> str: # pylint: disable=too-many-branc
114114
"""Pyarrow to Athena data types conversion."""
115115
if pa.types.is_int8(dtype):
116116
return "tinyint"
117-
if pa.types.is_int16(dtype):
117+
if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
118118
return "smallint"
119-
if pa.types.is_int32(dtype):
119+
if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
120120
return "int"
121-
if pa.types.is_int64(dtype):
121+
if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
122122
return "bigint"
123+
if pa.types.is_uint64(dtype):
124+
raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
123125
if pa.types.is_float32(dtype):
124126
return "float"
125127
if pa.types.is_float64(dtype):

awswrangler/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ def get_fs(
136136
fs: s3fs.S3FileSystem = s3fs.S3FileSystem(
137137
anon=False,
138138
use_ssl=True,
139-
default_cache_type="none",
139+
default_cache_type="readahead",
140140
default_fill_cache=False,
141-
default_block_size=134_217_728, # 128 MB (50 * 2**20)
141+
default_block_size=1_073_741_824, # 1024 MB (1024 * 2**20)
142142
config_kwargs={"retries": {"max_attempts": 15}},
143143
session=ensure_session(session=session)._session, # pylint: disable=protected-access
144144
s3_additional_kwargs=s3_additional_kwargs,

awswrangler/catalog.py

Lines changed: 165 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import re
77
import unicodedata
8-
from typing import Any, Dict, Iterator, List, Optional, Tuple
8+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
99
from urllib.parse import quote_plus
1010

1111
import boto3 # type: ignore
@@ -978,11 +978,21 @@ def _create_table(
978978
session: boto3.Session = _utils.ensure_session(session=boto3_session)
979979
client_glue: boto3.client = _utils.client(service_name="glue", session=session)
980980
exist: bool = does_table_exist(database=database, table=table, boto3_session=session)
981-
if mode not in ("overwrite", "append"): # pragma: no cover
982-
raise exceptions.InvalidArgument(f"{mode} is not a valid mode. It must be 'overwrite' or 'append'.")
981+
if mode not in ("overwrite", "append", "overwrite_partitions"): # pragma: no cover
982+
raise exceptions.InvalidArgument(
983+
f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'."
984+
)
983985
if (exist is True) and (mode == "overwrite"):
984986
skip_archive: bool = not catalog_versioning
987+
partitions_values: List[List[str]] = list(
988+
_get_partitions(database=database, table=table, boto3_session=session).values()
989+
)
990+
client_glue.batch_delete_partition(
991+
DatabaseName=database, TableName=table, PartitionsToDelete=[{"Values": v} for v in partitions_values]
992+
)
985993
client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
994+
elif (exist is True) and (mode in ("append", "overwrite_partitions")) and (parameters is not None):
995+
upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
986996
elif exist is False:
987997
client_glue.create_table(DatabaseName=database, TableInput=table_input)
988998

@@ -1327,3 +1337,155 @@ def extract_athena_types(
13271337
return _data_types.athena_types_from_pandas_partitioned(
13281338
df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=index_left
13291339
)
1340+
1341+
1342+
def get_table_parameters(
1343+
database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
1344+
) -> Dict[str, str]:
1345+
"""Get all parameters.
1346+
1347+
Parameters
1348+
----------
1349+
database : str
1350+
Database name.
1351+
table : str
1352+
Table name.
1353+
catalog_id : str, optional
1354+
The ID of the Data Catalog where the table resides.
1355+
If none is provided, the AWS account ID is used by default.
1356+
boto3_session : boto3.Session(), optional
1357+
Boto3 Session. The default boto3 session will be used if boto3_session is None.
1358+
1359+
Returns
1360+
-------
1361+
Dict[str, str]
1362+
Dictionary of parameters.
1363+
1364+
Examples
1365+
--------
1366+
>>> import awswrangler as wr
1367+
>>> pars = wr.catalog.get_table_parameters(database="...", table="...")
1368+
1369+
"""
1370+
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
1371+
args: Dict[str, str] = {}
1372+
if catalog_id is not None:
1373+
args["CatalogId"] = catalog_id # pragma: no cover
1374+
args["DatabaseName"] = database
1375+
args["Name"] = table
1376+
response: Dict[str, Any] = client_glue.get_table(**args)
1377+
parameters: Dict[str, str] = response["Table"]["Parameters"]
1378+
return parameters
1379+
1380+
1381+
def upsert_table_parameters(
1382+
parameters: Dict[str, str],
1383+
database: str,
1384+
table: str,
1385+
catalog_id: Optional[str] = None,
1386+
boto3_session: Optional[boto3.Session] = None,
1387+
) -> Dict[str, str]:
1388+
"""Insert or Update the received parameters.
1389+
1390+
Parameters
1391+
----------
1392+
parameters : Dict[str, str]
1393+
e.g. {"source": "mysql", "destination": "datalake"}
1394+
database : str
1395+
Database name.
1396+
table : str
1397+
Table name.
1398+
catalog_id : str, optional
1399+
The ID of the Data Catalog where the table resides.
1400+
If none is provided, the AWS account ID is used by default.
1401+
boto3_session : boto3.Session(), optional
1402+
Boto3 Session. The default boto3 session will be used if boto3_session is None.
1403+
1404+
Returns
1405+
-------
1406+
Dict[str, str]
1407+
All parameters after the upsert.
1408+
1409+
Examples
1410+
--------
1411+
>>> import awswrangler as wr
1412+
>>> pars = wr.catalog.upsert_table_parameters(
1413+
... parameters={"source": "mysql", "destination": "datalake"},
1414+
... database="...",
1415+
... table="...")
1416+
1417+
"""
1418+
session: boto3.Session = _utils.ensure_session(session=boto3_session)
1419+
pars: Dict[str, str] = get_table_parameters(
1420+
database=database, table=table, catalog_id=catalog_id, boto3_session=session
1421+
)
1422+
for k, v in parameters.items():
1423+
pars[k] = v
1424+
overwrite_table_parameters(
1425+
parameters=pars, database=database, table=table, catalog_id=catalog_id, boto3_session=session
1426+
)
1427+
return pars
1428+
1429+
1430+
def overwrite_table_parameters(
1431+
parameters: Dict[str, str],
1432+
database: str,
1433+
table: str,
1434+
catalog_id: Optional[str] = None,
1435+
boto3_session: Optional[boto3.Session] = None,
1436+
) -> Dict[str, str]:
1437+
"""Overwrite all existing parameters.
1438+
1439+
Parameters
1440+
----------
1441+
parameters : Dict[str, str]
1442+
e.g. {"source": "mysql", "destination": "datalake"}
1443+
database : str
1444+
Database name.
1445+
table : str
1446+
Table name.
1447+
catalog_id : str, optional
1448+
The ID of the Data Catalog where the table resides.
1449+
If none is provided, the AWS account ID is used by default.
1450+
boto3_session : boto3.Session(), optional
1451+
Boto3 Session. The default boto3 session will be used if boto3_session is None.
1452+
1453+
Returns
1454+
-------
1455+
Dict[str, str]
1456+
All parameters after the overwrite (the same ones that were received).
1457+
1458+
Examples
1459+
--------
1460+
>>> import awswrangler as wr
1461+
>>> pars = wr.catalog.overwrite_table_parameters(
1462+
... parameters={"source": "mysql", "destination": "datalake"},
1463+
... database="...",
1464+
... table="...")
1465+
1466+
"""
1467+
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
1468+
args: Dict[str, str] = {}
1469+
if catalog_id is not None:
1470+
args["CatalogId"] = catalog_id # pragma: no cover
1471+
args["DatabaseName"] = database
1472+
args["Name"] = table
1473+
response: Dict[str, Any] = client_glue.get_table(**args)
1474+
response["Table"]["Parameters"] = parameters
1475+
if "DatabaseName" in response["Table"]:
1476+
del response["Table"]["DatabaseName"]
1477+
if "CreateTime" in response["Table"]:
1478+
del response["Table"]["CreateTime"]
1479+
if "UpdateTime" in response["Table"]:
1480+
del response["Table"]["UpdateTime"]
1481+
if "CreatedBy" in response["Table"]:
1482+
del response["Table"]["CreatedBy"]
1483+
if "IsRegisteredWithLakeFormation" in response["Table"]:
1484+
del response["Table"]["IsRegisteredWithLakeFormation"]
1485+
args2: Dict[str, Union[str, Dict[str, Any]]] = {}
1486+
if catalog_id is not None:
1487+
args2["CatalogId"] = catalog_id # pragma: no cover
1488+
args2["DatabaseName"] = database
1489+
args2["TableInput"] = response["Table"]
1490+
client_glue.update_table(**args2)
1491+
return parameters

awswrangler/db.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def copy_to_redshift( # pylint: disable=too-many-arguments
422422
diststyle: str = "AUTO",
423423
distkey: Optional[str] = None,
424424
sortstyle: str = "COMPOUND",
425-
sortkey: Optional[str] = None,
425+
sortkey: Optional[List[str]] = None,
426426
primary_keys: Optional[List[str]] = None,
427427
varchar_lengths_default: int = 256,
428428
varchar_lengths: Optional[Dict[str, int]] = None,
@@ -485,7 +485,7 @@ def copy_to_redshift( # pylint: disable=too-many-arguments
485485
sortstyle : str
486486
Sorting can be "COMPOUND" or "INTERLEAVED".
487487
https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
488-
sortkey : str, optional
488+
sortkey : List[str], optional
489489
List of columns to be sorted.
490490
primary_keys : List[str], optional
491491
Primary keys.
@@ -569,7 +569,7 @@ def copy_files_to_redshift( # pylint: disable=too-many-locals,too-many-argument
569569
diststyle: str = "AUTO",
570570
distkey: Optional[str] = None,
571571
sortstyle: str = "COMPOUND",
572-
sortkey: Optional[str] = None,
572+
sortkey: Optional[List[str]] = None,
573573
primary_keys: Optional[List[str]] = None,
574574
varchar_lengths_default: int = 256,
575575
varchar_lengths: Optional[Dict[str, int]] = None,
@@ -616,7 +616,7 @@ def copy_files_to_redshift( # pylint: disable=too-many-locals,too-many-argument
616616
sortstyle : str
617617
Sorting can be "COMPOUND" or "INTERLEAVED".
618618
https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
619-
sortkey : str, optional
619+
sortkey : List[str], optional
620620
List of columns to be sorted.
621621
primary_keys : List[str], optional
622622
Primary keys.
@@ -716,7 +716,7 @@ def _rs_create_table(
716716
diststyle: str,
717717
sortstyle: str,
718718
distkey: Optional[str] = None,
719-
sortkey: Optional[str] = None,
719+
sortkey: Optional[List[str]] = None,
720720
primary_keys: Optional[List[str]] = None,
721721
) -> Tuple[str, Optional[str]]:
722722
if mode == "overwrite":
@@ -754,7 +754,7 @@ def _rs_create_table(
754754

755755

756756
def _rs_validate_parameters(
757-
redshift_types: Dict[str, str], diststyle: str, distkey: Optional[str], sortstyle: str, sortkey: Optional[str]
757+
redshift_types: Dict[str, str], diststyle: str, distkey: Optional[str], sortstyle: str, sortkey: Optional[List[str]]
758758
) -> None:
759759
if diststyle not in _RS_DISTSTYLES:
760760
raise exceptions.InvalidRedshiftDiststyle(f"diststyle must be in {_RS_DISTSTYLES}")

0 commit comments

Comments
 (0)