Skip to content

Commit 5af76e7

Browse files
authored
add auto_transaction connection arg to control whether a transaction is auto-started when creating a cursor (#76)
1 parent 66ff9c6 commit 5af76e7

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

pydataapi/dbapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class ConnectArgs(BaseModel):
9696
client: Optional[Any] = None
9797
rollback_exception: Optional[Type[Exception]] = None
9898
rds_client: Optional[Any] = None
99+
auto_transaction: Optional[bool] = True
99100

100101

101102
class Connection:
@@ -113,6 +114,7 @@ def __init__(self, **kwargs: Any) -> None:
113114
client=connect_args.client,
114115
rollback_exception=connect_args.rollback_exception,
115116
rds_client=connect_args.rds_client,
117+
auto_transaction=connect_args.auto_transaction,
116118
)
117119

118120
self.closed = False
@@ -132,7 +134,7 @@ def rollback(self) -> None:
132134
self._data_api._transaction_id = None
133135

134136
def cursor(self) -> 'Cursor':
135-
if not self._data_api.transaction_id:
137+
if not self._data_api.transaction_id and self._data_api.auto_transaction:
136138
self._data_api.begin()
137139
cursor = Cursor(self._data_api)
138140
self.cursors.append(cursor)

pydataapi/pydataapi.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def __init__(
365365
client: Optional[boto3.session.Session.client] = None,
366366
rollback_exception: Optional[Type[Exception]] = None,
367367
rds_client: Optional[boto3.session.Session.client] = None,
368+
auto_transaction: Optional[bool] = None,
368369
) -> None:
369370
if resource_name:
370371
if resource_arn:
@@ -389,6 +390,7 @@ def __init__(
389390
)
390391
self._transaction_status: Optional[str] = None
391392
self.rollback_exception: Optional[Type[Exception]] = rollback_exception
393+
self._auto_transaction: Optional[bool] = auto_transaction
392394

393395
def __enter__(self) -> "DataAPI":
394396
self.begin()
@@ -418,6 +420,10 @@ def transaction_id(self) -> Optional[str]:
418420
def transaction_status(self) -> Optional[str]:
419421
return self._transaction_status
420422

423+
@property
424+
def auto_transaction(self) -> Optional[bool]:
425+
return self._auto_transaction
426+
421427
def begin(
422428
self, database: Optional[str] = None, schema: Optional[str] = None
423429
) -> str:

tests/pydataapi/test_dbaapi.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,59 @@ class CustomError(Exception):
353353
)
354354
raise Exception('error')
355355
second_mocked_client.rollback_transaction.assert_not_called()
356+
357+
358+
def test_execute_select_w_auto_transaction(mocked_client, mocker) -> None:
359+
mocked_client.begin_transaction.return_value = {'transactionId': 'abc'}
360+
mocked_client.execute_statement.return_value = {
361+
'numberOfRecordsUpdated': 0,
362+
'records': [[{'longValue': 1}, {'stringValue': 'cat'}]],
363+
}
364+
data_api = connect(
365+
resource_arn='arn:aws:rds:dummy',
366+
secret_arn='dummy',
367+
database='test',
368+
client=mocked_client,
369+
)
370+
result = data_api.execute("select * from pets")
371+
assert result.rowcount == 1
372+
assert result.fetchone() == [1, 'cat']
373+
assert result.fetchone() is None
374+
assert mocked_client.execute_statement.call_args == mocker.call(
375+
continueAfterTimeout=True,
376+
includeResultMetadata=True,
377+
resourceArn='arn:aws:rds:dummy',
378+
secretArn='dummy',
379+
sql="select * from pets",
380+
database='test',
381+
transactionId='abc',
382+
)
383+
mocked_client.begin_transaction.assert_called_once()
384+
385+
386+
def test_execute_select_wo_auto_transaction(mocked_client, mocker) -> None:
387+
mocked_client.begin_transaction.return_value = {'transactionId': 'abc'}
388+
mocked_client.execute_statement.return_value = {
389+
'numberOfRecordsUpdated': 0,
390+
'records': [[{'longValue': 1}, {'stringValue': 'cat'}]],
391+
}
392+
data_api = connect(
393+
resource_arn='arn:aws:rds:dummy',
394+
secret_arn='dummy',
395+
database='test',
396+
client=mocked_client,
397+
auto_transaction=False,
398+
)
399+
result = data_api.execute("select * from pets")
400+
assert result.rowcount == 1
401+
assert result.fetchone() == [1, 'cat']
402+
assert result.fetchone() is None
403+
assert mocked_client.execute_statement.call_args == mocker.call(
404+
continueAfterTimeout=True,
405+
database='test',
406+
includeResultMetadata=True,
407+
resourceArn='arn:aws:rds:dummy',
408+
secretArn='dummy',
409+
sql='select * from pets',
410+
)
411+
mocked_client.begin_transaction.assert_not_called()

0 commit comments

Comments
 (0)