
Commit 2b6935f

Insert multiple rows at once in to_sql (#600)

1 parent 93aca5d · commit 2b6935f

11 files changed (+394, −315 lines)

awswrangler/_config.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -34,6 +34,7 @@ class _ConfigArg(NamedTuple):
     "max_local_cache_entries": _ConfigArg(dtype=int, nullable=False),
     "s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
     "workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
+    "chunksize": _ConfigArg(dtype=int, nullable=False, enforced=True),
     # Endpoints URLs
     "s3_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
     "athena_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True),
@@ -47,7 +48,7 @@ class _ConfigArg(NamedTuple):
 }
 
 
-class _Config:  # pylint: disable=too-many-instance-attributes
+class _Config:  # pylint: disable=too-many-instance-attributes, too-many-public-methods
     """Wrangler's Configuration class."""
 
     def __init__(self) -> None:
@@ -279,6 +280,15 @@ def workgroup(self) -> Optional[str]:
     def workgroup(self, value: Optional[str]) -> None:
         self._set_config_value(key="workgroup", value=value)
 
+    @property
+    def chunksize(self) -> int:
+        """Property chunksize."""
+        return cast(int, self["chunksize"])
+
+    @chunksize.setter
+    def chunksize(self, value: int) -> None:
+        self._set_config_value(key="chunksize", value=value)
+
     @property
     def s3_endpoint_url(self) -> Optional[str]:
         """Property s3_endpoint_url."""
```

awswrangler/_databases.py

Lines changed: 20 additions & 11 deletions

```diff
@@ -1,7 +1,7 @@
 """Databases Utilities."""
 
 import logging
-from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union, cast
+from typing import Any, Dict, Generator, Iterator, List, NamedTuple, Optional, Tuple, Union, cast
 
 import boto3
 import pandas as pd
@@ -219,13 +219,22 @@ def read_sql_query(
         raise
 
 
-def extract_parameters(df: pd.DataFrame) -> List[List[Any]]:
-    """Extract Parameters."""
-    parameters: List[List[Any]] = df.values.tolist()
-    for i, row in enumerate(parameters):
-        for j, value in enumerate(row):
-            if pd.isna(value):
-                parameters[i][j] = None
-            elif hasattr(value, "to_pydatetime"):
-                parameters[i][j] = value.to_pydatetime()
-    return parameters
+def generate_placeholder_parameter_pairs(
+    df: pd.DataFrame, column_placeholders: str, chunksize: int
+) -> Generator[Tuple[str, List[Any]], None, None]:
+    """Extract Placeholder and Parameter pairs."""
+
+    def convert_value_to_native_python_type(value: Any) -> Any:
+        if pd.isna(value):
+            return None
+        if hasattr(value, "to_pydatetime"):
+            return value.to_pydatetime()
+
+        return value
+
+    parameters = df.values.tolist()
+    for i in range(0, len(df.index), chunksize):
+        parameters_chunk = parameters[i : i + chunksize]
+        chunk_placeholders = ", ".join([f"({column_placeholders})" for _ in range(len(parameters_chunk))])
+        flattened_chunk = [convert_value_to_native_python_type(value) for row in parameters_chunk for value in row]
+        yield chunk_placeholders, flattened_chunk
```
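
A worked example of what the new generator yields (a sketch with made-up data):

```python
import pandas as pd

from awswrangler._databases import generate_placeholder_parameter_pairs

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", None]})

# Two columns -> column_placeholders "%s, %s"; chunksize=2 -> at most
# two rows per yielded pair, values flattened into one parameter list.
for placeholders, parameters in generate_placeholder_parameter_pairs(
    df=df, column_placeholders="%s, %s", chunksize=2
):
    print(placeholders, "|", parameters)
# (%s, %s), (%s, %s) | [1, 'x', 2, 'y']
# (%s, %s) | [3, None]
```

Each yielded pair becomes exactly one multi-row INSERT downstream, so a 1,000-row frame with the default chunksize of 200 costs five statements instead of 1,000.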

awswrangler/mysql.py

Lines changed: 14 additions & 6 deletions

```diff
@@ -12,6 +12,7 @@
 from awswrangler import _data_types
 from awswrangler import _databases as _db_utils
 from awswrangler import exceptions
+from awswrangler._config import apply_configs
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -257,6 +258,7 @@ def read_sql_table(
     )
 
 
+@apply_configs
 def to_sql(
     df: pd.DataFrame,
     con: pymysql.connections.Connection,
@@ -267,6 +269,7 @@ def to_sql(
     dtype: Optional[Dict[str, str]] = None,
     varchar_lengths: Optional[Dict[str, int]] = None,
     use_column_names: bool = False,
+    chunksize: int = 200,
 ) -> None:
     """Write records stored in a DataFrame into MySQL.
 
@@ -295,6 +298,8 @@ def to_sql(
         If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
         E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
         inserted into the database columns `col1` and `col3`.
+    chunksize: int
+        Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
 
     Returns
     -------
@@ -308,7 +313,7 @@ def to_sql(
     >>> import awswrangler as wr
     >>> con = wr.mysql.connect("MY_GLUE_CONNECTION")
     >>> wr.mysql.to_sql(
-    ...     df=df
+    ...     df=df,
     ...     table="my_table",
     ...     schema="test",
     ...     con=con
@@ -333,14 +338,17 @@ def to_sql(
             )
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
-            placeholders: str = ", ".join(["%s"] * len(df.columns))
+            column_placeholders: str = ", ".join(["%s"] * len(df.columns))
             insertion_columns = ""
             if use_column_names:
                 insertion_columns = f"({', '.join(df.columns)})"
-            sql: str = f"INSERT INTO `{schema}`.`{table}` {insertion_columns} VALUES ({placeholders})"
-            _logger.debug("sql: %s", sql)
-            parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
-            cursor.executemany(sql, parameters)
+            placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
+                df=df, column_placeholders=column_placeholders, chunksize=chunksize
+            )
+            for placeholders, parameters in placeholder_parameter_pair_generator:
+                sql: str = f"INSERT INTO `{schema}`.`{table}` {insertion_columns} VALUES {placeholders}"
+                _logger.debug("sql: %s", sql)
+                cursor.executemany(sql, (parameters,))
             con.commit()
     except Exception as ex:
         con.rollback()
```
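
The effect on the generated SQL: instead of one `INSERT ... VALUES (%s, %s)` per row, each chunk produces a single statement with the placeholder tuple repeated. A sketch of the shape for three rows of two columns (illustrative schema/table names):

```python
column_placeholders = ", ".join(["%s"] * 2)   # "%s, %s"
placeholders = ", ".join(["(%s, %s)"] * 3)    # "(%s, %s), (%s, %s), (%s, %s)"
sql = f"INSERT INTO `test`.`my_table`  VALUES {placeholders}"

# The chunk's rows arrive flattened into one parameter list; wrapping it
# in a one-element tuple makes executemany() run the statement just once:
# cursor.executemany(sql, ([1, "x", 2, "y", 3, None],))
```

Passing `(parameters,)` rather than `parameters` is deliberate: the rows are already expanded into the VALUES clause, so `executemany` only needs a single parameter set per chunk.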

awswrangler/postgresql.py

Lines changed: 13 additions & 5 deletions

```diff
@@ -11,6 +11,7 @@
 from awswrangler import _data_types
 from awswrangler import _databases as _db_utils
 from awswrangler import exceptions
+from awswrangler._config import apply_configs
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -263,6 +264,7 @@ def read_sql_table(
     )
 
 
+@apply_configs
 def to_sql(
     df: pd.DataFrame,
     con: pg8000.Connection,
@@ -273,6 +275,7 @@ def to_sql(
     dtype: Optional[Dict[str, str]] = None,
     varchar_lengths: Optional[Dict[str, int]] = None,
     use_column_names: bool = False,
+    chunksize: int = 200,
 ) -> None:
     """Write records stored in a DataFrame into PostgreSQL.
 
@@ -301,6 +304,8 @@ def to_sql(
         If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
         E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
         inserted into the database columns `col1` and `col3`.
+    chunksize: int
+        Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
 
     Returns
     -------
@@ -339,14 +344,17 @@ def to_sql(
             )
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
-            placeholders: str = ", ".join(["%s"] * len(df.columns))
+            column_placeholders: str = ", ".join(["%s"] * len(df.columns))
             insertion_columns = ""
             if use_column_names:
                 insertion_columns = f"({', '.join(df.columns)})"
-            sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES ({placeholders})'
-            _logger.debug("sql: %s", sql)
-            parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
-            cursor.executemany(sql, parameters)
+            placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
+                df=df, column_placeholders=column_placeholders, chunksize=chunksize
+            )
+            for placeholders, parameters in placeholder_parameter_pair_generator:
+                sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}'
+                _logger.debug("sql: %s", sql)
+                cursor.executemany(sql, (parameters,))
             con.commit()
     except Exception as ex:
         con.rollback()
```

awswrangler/redshift.py

Lines changed: 14 additions & 6 deletions

```diff
@@ -13,6 +13,7 @@
 from awswrangler import _data_types
 from awswrangler import _databases as _db_utils
 from awswrangler import _utils, exceptions, s3
+from awswrangler._config import apply_configs
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -629,6 +630,7 @@ def read_sql_table(
     )
 
 
+@apply_configs
 def to_sql(
     df: pd.DataFrame,
     con: redshift_connector.Connection,
@@ -645,6 +647,7 @@ def to_sql(
     varchar_lengths_default: int = 256,
     varchar_lengths: Optional[Dict[str, int]] = None,
     use_column_names: bool = False,
+    chunksize: int = 200,
 ) -> None:
     """Write records stored in a DataFrame into Redshift.
 
@@ -693,6 +696,8 @@ def to_sql(
         If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
         E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
         inserted into the database columns `col1` and `col3`.
+    chunksize: int
+        Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
 
     Returns
     -------
@@ -706,7 +711,7 @@ def to_sql(
     >>> import awswrangler as wr
     >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
     >>> wr.redshift.to_sql(
-    ...     df=df
+    ...     df=df,
     ...     table="my_table",
     ...     schema="public",
     ...     con=con
@@ -740,15 +745,18 @@ def to_sql(
             )
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
-            placeholders: str = ", ".join(["%s"] * len(df.columns))
+            column_placeholders: str = ", ".join(["%s"] * len(df.columns))
             schema_str = f'"{created_schema}".' if created_schema else ""
             insertion_columns = ""
             if use_column_names:
                 insertion_columns = f"({', '.join(df.columns)})"
-            sql: str = f'INSERT INTO {schema_str}"{created_table}" {insertion_columns} VALUES ({placeholders})'
-            _logger.debug("sql: %s", sql)
-            parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
-            cursor.executemany(sql, parameters)
+            placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
+                df=df, column_placeholders=column_placeholders, chunksize=chunksize
+            )
+            for placeholders, parameters in placeholder_parameter_pair_generator:
+                sql: str = f'INSERT INTO {schema_str}"{created_table}" {insertion_columns} VALUES {placeholders}'
+                _logger.debug("sql: %s", sql)
+                cursor.executemany(sql, (parameters,))
             if table != created_table:  # upsert
                 _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
             con.commit()
```

awswrangler/sqlserver.py

Lines changed: 17 additions & 5 deletions

```diff
@@ -2,6 +2,7 @@
 
 
 import importlib.util
+import inspect
 import logging
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
 
@@ -12,6 +13,7 @@
 from awswrangler import _data_types
 from awswrangler import _databases as _db_utils
 from awswrangler import exceptions
+from awswrangler._config import apply_configs
 
 __all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"]
 
@@ -32,6 +34,9 @@ def inner(*args: Any, **kwargs: Any) -> Any:
             )
         return func(*args, **kwargs)
 
+    inner.__doc__ = func.__doc__
+    inner.__name__ = func.__name__
+    inner.__setattr__("__signature__", inspect.signature(func))  # pylint: disable=no-member
     return inner  # type: ignore
 
 
@@ -281,6 +286,7 @@ def read_sql_table(
 
 
 @_check_for_pyodbc
+@apply_configs
 def to_sql(
     df: pd.DataFrame,
     con: "pyodbc.Connection",
@@ -291,6 +297,7 @@ def to_sql(
     dtype: Optional[Dict[str, str]] = None,
     varchar_lengths: Optional[Dict[str, int]] = None,
     use_column_names: bool = False,
+    chunksize: int = 200,
 ) -> None:
     """Write records stored in a DataFrame into Microsoft SQL Server.
 
@@ -319,6 +326,8 @@ def to_sql(
         If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
         E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
         inserted into the database columns `col1` and `col3`.
+    chunksize: int
+        Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
 
     Returns
     -------
@@ -357,15 +366,18 @@ def to_sql(
             )
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
-            placeholders: str = ", ".join(["?"] * len(df.columns))
+            column_placeholders: str = ", ".join(["?"] * len(df.columns))
             table_identifier = _get_table_identifier(schema, table)
             insertion_columns = ""
             if use_column_names:
                 insertion_columns = f"({', '.join(df.columns)})"
-            sql: str = f"INSERT INTO {table_identifier} {insertion_columns} VALUES ({placeholders})"
-            _logger.debug("sql: %s", sql)
-            parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
-            cursor.executemany(sql, parameters)
+            placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
+                df=df, column_placeholders=column_placeholders, chunksize=chunksize
+            )
+            for placeholders, parameters in placeholder_parameter_pair_generator:
+                sql: str = f"INSERT INTO {table_identifier} {insertion_columns} VALUES {placeholders}"
+                _logger.debug("sql: %s", sql)
+                cursor.executemany(sql, (parameters,))
             con.commit()
     except Exception as ex:
         con.rollback()
```
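
Copying `__doc__`, `__name__`, and `__signature__` onto `inner` matters because `@apply_configs` sits below `@_check_for_pyodbc` and decides what to inject by inspecting the callable it receives. A minimal sketch of that dependency (hypothetical internals, not Wrangler's actual `apply_configs`):

```python
import inspect
from typing import Any, Callable


def apply_configs_sketch(function: Callable[..., Any]) -> Callable[..., Any]:
    # Without the forwarded __signature__, a bare wrapper would expose only
    # (*args, **kwargs) and the chunksize parameter would stay invisible here.
    params = inspect.signature(function).parameters

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if "chunksize" in params and "chunksize" not in kwargs:
            kwargs["chunksize"] = 200  # in Wrangler this comes from wr.config
        return function(*args, **kwargs)

    return wrapper
```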

tests/test_config.py

Lines changed: 22 additions & 1 deletion

```diff
@@ -1,6 +1,6 @@
 import logging
 import os
-from unittest.mock import patch
+from unittest.mock import create_autospec, patch
 
 import boto3
 import botocore
@@ -9,6 +9,7 @@
 import pytest
 
 import awswrangler as wr
+from awswrangler._config import apply_configs
 from awswrangler.s3._fs import open_s3_object
 
 logging.getLogger("awswrangler").setLevel(logging.DEBUG)
@@ -180,3 +181,23 @@ def wrapper(self, **kwarg):
         s3obj.write(b"foo")
 
     wr.config.reset()
+
+
+def test_chunk_size():
+    expected_chunksize = 123
+
+    wr.config.chunksize = expected_chunksize
+
+    for function_to_mock in [wr.postgresql.to_sql, wr.mysql.to_sql, wr.sqlserver.to_sql, wr.redshift.to_sql]:
+        mock = create_autospec(function_to_mock)
+        apply_configs(mock)(df=None, con=None, table=None, schema=None)
+        mock.assert_called_with(df=None, con=None, table=None, schema=None, chunksize=expected_chunksize)
+
+    expected_chunksize = 456
+    os.environ["WR_CHUNKSIZE"] = str(expected_chunksize)
+    wr.config.reset()
+
+    for function_to_mock in [wr.postgresql.to_sql, wr.mysql.to_sql, wr.sqlserver.to_sql, wr.redshift.to_sql]:
+        mock = create_autospec(function_to_mock)
+        apply_configs(mock)(df=None, con=None, table=None, schema=None)
+        mock.assert_called_with(df=None, con=None, table=None, schema=None, chunksize=expected_chunksize)
```
